Source code for assembly.areas

from collections import OrderedDict
from string import ascii_uppercase

import torch
import torch.nn as nn
from abc import ABC, abstractmethod

from assembly.constants import K_ACTIVE
from assembly.samplers import sample_bernoulli
from mighty.utils.common import find_layers

# Public API of ``assembly.areas``: the names exported by a star-import.
__all__ = [
    "KWinnersTakeAll",
    "AreaRNNHebb",
    "AreaRNNWillshaw",
    "AreaStack",
    "AreaSequential",
]


class KWinnersTakeAll(nn.Module):
    """
    K-winners-take-all activation function.

    Parameters
    ----------
    k_active : int, optional
        `k`, the number of active (winner) neurons within an output layer.
        Default: ``K_ACTIVE``
    """

    def __init__(self, k_active=K_ACTIVE):
        super().__init__()
        self.k_active = k_active

    def forward(self, x):
        """
        The forward pass of kWTA.

        Parameters
        ----------
        x : (..., N) torch.Tensor
            An input vector, or a batch of input vectors. The winners are
            selected independently along the last dimension.

        Returns
        -------
        y : (..., N) torch.Tensor
            The output ``y = kwta(x)`` of the same shape and dtype as `x`,
            with exactly :attr:`k_active` entries set to one along the last
            dimension and zeros elsewhere.
        """
        winners = x.topk(k=self.k_active, dim=-1, sorted=False).indices
        y = torch.zeros_like(x)
        # Scatter along the last axis instead of ``y[winners] = 1``: plain
        # indexing is only correct for 1-D input and would set whole rows
        # of a batched input.
        y.scatter_(-1, winners, 1)
        return y

    def extra_repr(self):
        return f"k_active={self.k_active}"
class AreaInterface(nn.Module, ABC):
    """
    Shared helpers for area modules: inference without learning, memory
    usage statistics, and weight normalization.
    """

    def _forward_in_eval_mode(self, *args, **kwargs):
        # Run a forward pass in eval mode and restore the previous training
        # flag even if the forward pass raises.
        was_training = self.training
        self.eval()
        try:
            return self(*args, **kwargs)
        finally:
            self.train(was_training)

    def recall(self, xs_stim):
        """
        A forward pass without latent activations.

        The module is temporarily switched to the evaluation mode; the
        previous training mode is restored afterwards.

        Parameters
        ----------
        xs_stim : torch.Tensor or tuple of torch.Tensor
            Input vectors from the incoming areas.

        Returns
        -------
        y_out : torch.Tensor
            The output vector.
        """
        return self._forward_in_eval_mode(xs_stim)

    def complete_from_input(self, xs_partial, y_partial=None):
        """
        Complete the pattern from the partial input.

        Nothing more than a simple forward pass in the evaluation mode,
        i.e. without updating the weights.

        Parameters
        ----------
        xs_partial : torch.Tensor or tuple of torch.Tensor
            Partially active input vectors from the incoming areas.
        y_partial : torch.Tensor or None, optional
            The stored latent (hidden activations) vector from the previous
            step with partial activations.
            Default: None

        Returns
        -------
        y_out : torch.Tensor
            The output vector.
        """
        return self._forward_in_eval_mode(xs_partial, y_latent=y_partial)

    def memory_used(self):
        r"""
        Computes the used memory bits as
        :math:`\frac{||W||_0}{\text{size}(W)}`

        Returns
        -------
        dict
            A dictionary with used memory for each parameter (weight matrix),
            keyed by the parameter name.
        """
        memory_used = {}
        for name, param in self.named_parameters():
            # L0 "norm" counts the non-zero entries of the weight
            memory_used[name] = param.norm(p=0) / param.nelement()
        return memory_used

    def normalize_weights(self):
        """
        Normalize the pre-synaptic weights sum to ``1.0``.

        Without the normalization, all inputs converge to the same output
        vector determined by the lateral weights because the sum
        ``w_xy @ x + w_lat @ y`` favors the second element. Normalization of
        the feedforward and lateral weights makes ``w_xy @ x`` and
        ``w_lat @ y`` of the same magnitude.
        """
        for module in find_layers(self, layer_class=AreaRNN):
            for weight in module.parameters(recurse=False):
                # input and recurrent weights
                module._normalize_weight(weight)
                assert torch.isfinite(weight).all()


class AreaRNN(AreaInterface, ABC):
    """
    A base recurrent area with one input weight matrix per incoming area,
    one recurrent weight matrix, and a k-winners-take-all output.

    Parameters
    ----------
    *in_features
        The sizes of input vectors from incoming areas.
    out_features : int
        The size of the output layer.
    p_synapse : float, optional
        Passed to `sampler` as ``proba`` when the weights are created.
        Default: 0.05
    recurrent_coef : float, optional
        The multiplier of the recurrent term in the forward pass.
        Default: 1.0
    sampler : callable, optional
        Weights initialization function, called as
        ``sampler(out_features, in_features, proba=p_synapse)``.
        Default: sample_bernoulli
    """

    def __init__(self, *in_features: int, out_features, p_synapse=0.05,
                 recurrent_coef=1., sampler=sample_bernoulli):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.recurrent_coef = recurrent_coef
        # Plain list that keeps the registered input weights in the order of
        # the incoming areas.
        self.weights_input = []
        for parent_id, neurons_in in enumerate(in_features):
            weight_in = nn.Parameter(
                sampler(out_features, neurons_in, proba=p_synapse),
                requires_grad=False)
            self.register_parameter(name=f"weight_input{parent_id}",
                                    param=weight_in)
            self.weights_input.append(weight_in)
        self.weight_recurrent = nn.Parameter(
            sampler(out_features, out_features, proba=p_synapse),
            requires_grad=False)
        self.kwta = KWinnersTakeAll()
        self.normalize_weights()

    def forward(self, xs_stim, y_latent=None):
        """
        The forward pass :eq:`forward`.

        In the training mode, the input and recurrent weights are updated
        with :meth:`update_weight` after the output is computed.

        Parameters
        ----------
        xs_stim : torch.Tensor or tuple of torch.Tensor
            Input vectors from the incoming areas.
        y_latent : torch.Tensor or None, optional
            The stored latent (hidden activations) vector from the previous
            step.
            Default: None

        Returns
        -------
        y_out : torch.Tensor or None
            The output vector, or None if no input vector is provided.
        """
        if isinstance(xs_stim, torch.Tensor):
            xs_stim = [xs_stim]
        if xs_stim is None or all(x is None for x in xs_stim):
            return None
        assert len(xs_stim) == len(self.weights_input)
        # Allocate the accumulator on the device of the weights so that the
        # module keeps working after it is moved off the default device.
        y_out = torch.zeros(self.out_features,
                            device=self.weight_recurrent.device)
        for x, w_in in zip(xs_stim, self.weights_input):
            if x is not None:
                y_out += w_in.matmul(x)
        if y_latent is not None:
            # y_out += alpha * W_rec @ y_latent
            y_out.addmv_(mat=self.weight_recurrent, vec=y_latent,
                         alpha=self.recurrent_coef)
        y_out = self.kwta(y_out)
        if self.training:
            for x, w_in in zip(xs_stim, self.weights_input):
                if x is not None:
                    self.update_weight(w_in, x=x, y=y_out)
            if y_latent is not None:
                self.update_weight(self.weight_recurrent, x=y_latent, y=y_out)
        return y_out

    def update_weight(self, weight, x, y):
        """
        Update the weight, given the activations.

        The base implementation does nothing.

        Parameters
        ----------
        weight : torch.Tensor
            The weight to update.
        x, y : torch.Tensor
            Input and output vectors.
        """
        pass

    @abstractmethod
    def _normalize_weight(self, weight):
        """
        Normalize the pre-synaptic weight sum to ``1.0``.

        Parameters
        ----------
        weight : torch.Tensor
            A weight matrix.
        """
        pass

    def complete_pattern(self, y_partial):
        """
        Complete the pattern using the recurrent connections only.

        Parameters
        ----------
        y_partial : torch.Tensor
            A partially activated latent vector.

        Returns
        -------
        y : torch.Tensor
            The reconstructed vector `y`.
        """
        y = self.weight_recurrent.matmul(y_partial)
        y = self.kwta(y)
        return y

    def extra_repr(self):
        return f"in_features: {self.in_features}, " \
               f"out_features: {self.out_features}, " \
               f"recurrent_coef={self.recurrent_coef}"
class AreaRNNHebb(AreaRNN):
    r"""
    A Hebbian-learning recurrent neural network with one or more incoming
    input layers and only one output layer.

    The update rule, if :math:`x_j` and :math:`y_i` neurons fired:

    * additive:

      .. math::
          :label: update-additive

          W_{ij} = W_{ij} + \beta

    * multiplicative:

      .. math::
          :label: update-multiplicative

          W_{ij} = W_{ij} * (1 + \beta)

    After each epoch, many repetitions of the same input trial, the weights
    are normalized to have ``1.0`` in its pre-synaptic sum for each neuron.

    Parameters
    ----------
    *in_features
        The sizes of input vectors from incoming areas.
    out_features : int
        The size of the output layer.
    p_synapse : float, optional
        The initial probability of recurrent and afferent synaptic
        connectivity.
        Default: 0.05
    recurrent_coef : float, optional
        The recurrent coefficient :math:`\alpha` in :eq:`forward`.
        Default: 1
    learning_rate : float, optional
        The plasticity coefficient :math:`\beta` in :eq:`update-additive` and
        :eq:`update-multiplicative`.
        Default: 0.1
    sampler : {sample_bernoulli, sample_uniform_masked}, optional
        Weights initialization function to call: either Bernoulli or uniform.
        Default: sample_bernoulli
    update : {'additive', 'multiplicative'}, optional
        The weight update learning rule.
        Default: 'multiplicative'

    Raises
    ------
    ValueError
        If `update` is neither ``'additive'`` nor ``'multiplicative'``.

    Notes
    -----
    `'additive'` update learning rule allows new weights to grow, as opposed
    to `'multiplicative'`.
    """

    def __init__(self, *in_features: int, out_features, p_synapse=0.05,
                 recurrent_coef=1., learning_rate=0.1,
                 sampler=sample_bernoulli, update='multiplicative'):
        super().__init__(*in_features, out_features=out_features,
                         p_synapse=p_synapse, recurrent_coef=recurrent_coef,
                         sampler=sampler)
        self.learning_rate = learning_rate
        # Bind the chosen rule as the instance's ``update_weight``.
        if update == 'additive':
            self.update_weight = self.update_weight_additive
        elif update == 'multiplicative':
            self.update_weight = self.update_weight_multiplicative
        else:
            raise ValueError(f"Invalid update rule: '{update}'")

    def update_weight_additive(self, weight, x, y):
        # w_ij = w_ij + learning_rate, if x_j and y_i fired:
        #   w_ij = w_ij + learning_rate * x_j * y_i
        weight.addr_(y, x, alpha=self.learning_rate)

    def update_weight_multiplicative(self, weight, x, y):
        # w_ij = w_ij * (1 + learning_rate), if x_j and y_i fired:
        #   w_ij = w_ij * (1 + learning_rate * x_j * y_i)
        weight.mul_(1 + self.learning_rate * y.unsqueeze(1) * x.unsqueeze(0))

    def _normalize_weight(self, weight):
        presum = weight.sum(dim=1, keepdim=True)
        presum[presum == 0] = 1  # all elements in a row are zeros
        weight /= presum

    def extra_repr(self):
        prefix = 'update_weight_'
        update = self.update_weight.__name__
        # str.lstrip() removes a *set of characters*, not a prefix, and would
        # turn 'update_weight_additive' into 've'; strip the prefix
        # explicitly instead.
        if update.startswith(prefix):
            update = update[len(prefix):]
        return f"{super().extra_repr()}, update='{update}'"
class AreaRNNWillshaw(AreaRNN):
    r"""
    Non-Holographic Associative Memory Area [1]_: a recurrent neural network
    with one or more incoming input layers and only one output layer. The
    weights are initialized with ``sample_bernoulli`` and are capped at one
    on every update.

    The update rule, if :math:`x_j` and :math:`y_i` neurons fired:

    .. math::
        :label: update-will

        W_{ij} = 1

    Compared to :class:`AreaRNNHebb`, this rule uses neither a learning rate
    nor a weight normalization.

    Parameters
    ----------
    *in_features
        The sizes of input vectors from incoming areas.
    out_features : int
        The size of the output layer.
    p_synapse : float, optional
        The initial probability of recurrent and afferent synaptic
        connectivity.
        Default: 0.05
    recurrent_coef : float, optional
        The recurrent coefficient :math:`\alpha` in :eq:`forward`.
        Default: 1.0
    **ignored
        Any other keyword arguments are accepted and discarded.

    References
    ----------
    .. [1] Willshaw, D. J., Buneman, O. P., & Longuet-Higgins, H. C. (1969).
       Non-holographic associative memory. Nature, 222(5197), 960-962.
    """

    def __init__(self, *in_features: int, out_features, p_synapse=0.05,
                 recurrent_coef=1, **ignored):
        # The sampler is fixed to Bernoulli for this area.
        area_kwargs = dict(
            out_features=out_features,
            p_synapse=p_synapse,
            recurrent_coef=recurrent_coef,
            sampler=sample_bernoulli,
        )
        super().__init__(*in_features, **area_kwargs)

    def update_weight(self, weight, x, y):
        # Add the outer product y x^T in place, then cap every entry at 1,
        # so that w_ij becomes 1 wherever both x_j and y_i fired.
        weight.addr_(y, x).clamp_max_(1)

    def _normalize_weight(self, weight):
        # Nothing to do: the update rule already caps the weights at one.
        return None
class AreaStack(nn.Sequential, AreaInterface):
    """
    Vertically stacked areas that are run side by side: the i-th area
    processes the i-th input, and the outputs are returned as a list.

    The areas are registered under the names ``'A'``, ``'B'``, ... in the
    order they are passed.

    Parameters
    ----------
    *areas
        Vertically stacked :class:`AreaRNN`.

    Raises
    ------
    ValueError
        If more areas are passed than there are uppercase ASCII letters to
        name them.
    """

    def __init__(self, *areas: AreaRNN):
        # zip() below would silently drop the areas that have no letter
        # left, so refuse them explicitly.
        if len(areas) > len(ascii_uppercase):
            raise ValueError(
                f"AreaStack supports at most {len(ascii_uppercase)} areas, "
                f"got {len(areas)}")
        areas_named = OrderedDict(zip(ascii_uppercase, areas))
        nn.Sequential.__init__(self, areas_named)

    def forward(self, xs_stim, y_latent=None):
        """
        Run each area on its own input.

        Parameters
        ----------
        xs_stim : sequence or None
            One input per area, in the order of the areas. None stands for
            "no input for any area".
        y_latent : sequence or None, optional
            One latent vector (or None) per area from the previous step.
            Default: None

        Returns
        -------
        y_out : list
            The outputs of the areas, in the order of the areas.
        """
        if xs_stim is None:
            xs_stim = [None] * len(self)
        assert len(xs_stim) == len(self)
        if y_latent is None:
            y_latent = [None] * len(xs_stim)
        y_out = [area(x, y_latent=yl)
                 for area, x, yl in zip(self, xs_stim, y_latent)]
        return y_out
class AreaSequential(nn.Sequential, AreaInterface):
    """
    A chain of areas: each area receives the output of the previous one.
    """

    def forward(self, xs_stim, y_latent=None):
        """
        Propagate the input through the areas one after another.

        Parameters
        ----------
        xs_stim : sequence or None
            The input of the first area. None is replaced with a list of
            None, one per area.
        y_latent : sequence or None, optional
            One latent vector (or None) per area from the previous step.
            Default: None

        Returns
        -------
        y_out
            The output of the last area.
        y_intermediate : list
            The outputs of all the areas, in order; the last item is `y_out`.
        """
        n_areas = len(self)
        if xs_stim is None:
            xs_stim = [None] * n_areas
        assert len(xs_stim) == n_areas
        if y_latent is None:
            y_latent = [None] * n_areas
        signal = xs_stim
        hidden = []  # hidden activations of the intermediate layers
        for area, latent in zip(self, y_latent):
            signal = area(signal, y_latent=latent)
            hidden.append(signal)
        return signal, hidden

    def recall(self, xs_stim):
        """
        A forward pass without latent activations that returns the final
        output only; the intermediate activations are dropped.
        """
        final_output, _ = super().recall(xs_stim)
        return final_output