assembly.KWinnersTakeAll

class assembly.KWinnersTakeAll(k_active=50)

Bases: torch.nn.modules.module.Module

K-winners-take-all activation function.

Parameters
k_active : int, optional

k, the number of active (winner) neurons within an output layer. Default: 50
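To illustrate what a layer with this signature computes, here is a minimal sketch. It is not the library's source. It assumes a 1-D input of shape (N,) with k_active <= N, and it assumes a binary output in which the k winners are set to 1 and all other neurons to 0; the real assembly.KWinnersTakeAll may, for example, keep the winning values instead. The class name KWinnersTakeAllSketch is hypothetical.

```python
import torch
from torch import nn


class KWinnersTakeAllSketch(nn.Module):
    """Hypothetical binary kWTA layer (illustration only, not the library's code)."""

    def __init__(self, k_active: int = 50):
        super().__init__()
        # Number of winner neurons that stay active in the output.
        self.k_active = k_active

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Indices of the k largest entries of the 1-D input vector.
        winners = torch.topk(x, k=self.k_active).indices
        # Start with every neuron silent, then switch on only the winners.
        y = torch.zeros_like(x)
        y[winners] = 1
        return y


kwta = KWinnersTakeAllSketch(k_active=3)
x = torch.tensor([0.1, 2.0, -1.0, 0.7, 1.5, 0.3])
y = kwta(x)
print(y)  # tensor([0., 1., 0., 1., 1., 0.])
```

The three largest inputs (2.0, 1.5 and 0.7) become the only active neurons; every other position is zeroed.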

Methods

forward(x)

The forward pass of kWTA.


forward(x)

The forward pass of kWTA.

Parameters
x : (N,) torch.Tensor

An input vector.

Returns
y : (N,) torch.Tensor

The output vector y = kwta(x), with exactly k = k_active active neurons.
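The guarantee of exactly k active neurons matters when the input contains ties. The sketch below, which is not taken from the library, contrasts two hypothetical helpers: kwta_threshold keeps every entry at least as large as the k-th largest value and can therefore activate more than k neurons, while kwta_topk scatters ones at exactly the k indices returned by torch.topk. How assembly.KWinnersTakeAll itself breaks ties is not stated here and is an assumption left open.

```python
import torch


def kwta_threshold(x: torch.Tensor, k: int) -> torch.Tensor:
    # Naive variant: activate every entry >= the k-th largest value.
    # With ties at that value, more than k neurons can become active.
    kth_value = torch.topk(x, k).values[-1]
    return (x >= kth_value).float()


def kwta_topk(x: torch.Tensor, k: int) -> torch.Tensor:
    # Scatter ones at exactly the k indices chosen by topk,
    # so the number of active neurons is always k.
    y = torch.zeros_like(x)
    return y.scatter(0, torch.topk(x, k).indices, 1.0)


x = torch.tensor([0.5, 0.9, 0.5, 0.5, 0.1])
print(int(kwta_threshold(x, k=2).sum()))  # 4
print(int(kwta_topk(x, k=2).sum()))       # 2
```

With three inputs tied at 0.5, the threshold variant activates four neurons, whereas the topk variant keeps the winner 0.9 plus exactly one of the tied entries.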