Source code for assembly.monitor

import math
import time
from collections import defaultdict

import numpy as np
import torch
from mighty.monitor.batch_timer import timer
from mighty.monitor.viz import VisdomMighty
from mighty.utils.common import find_named_layers
from torchvision import transforms
from torchvision.transforms import InterpolationMode

from assembly.areas import AreaInterface, AreaRNN
from assembly.constants import K_ACTIVE, N_NEURONS
from assembly.graph import GraphArea, graphviz_notify_if_not_installed
from assembly.utils import factors_root


[docs]def expected_random_overlap(n, k): """ Computes the expected overlap of binomial sampled random vectors. Parameters ---------- n : int The total number of neurons. k : int The number of active neurons. Returns ------- float The expected random overlap. """ p_overlap = [math.comb(k, x) * math.comb(n - k, k - x) / math.comb(n, k) for x in range(k + 1)] overlap_expected = np.multiply(p_overlap, range(k + 1)).sum() return overlap_expected
[docs]def pairwise_similarity(tensors): """ Computes the pairwise similarity overlap of the tensors. Parameters ---------- tensors : list of torch.Tensor A list of binary vectors. Each entry can be either a single vector tensor (one incoming area) or a tuple of tensors (multiple incoming areas). Returns ------- similarity : float The pairwise :math:`L_{0/1}` similarity from 0 to 1. """ tensors = [t for t in tensors if t is not None] if len(tensors) <= 1: return np.nan if not isinstance(tensors[0], torch.Tensor): # multiple incoming areas sim_areas = list(map(pairwise_similarity, zip(*tensors))) sim_areas = np.nanmean(sim_areas) return sim_areas else: tensors = torch.stack(tensors) similarity = tensors.matmul(tensors.t()) n_elements = len(tensors) ii, jj = torch.triu_indices(row=n_elements, col=n_elements, offset=1) similarity = similarity[ii, jj].mean() similarity /= K_ACTIVE return similarity
class VisdomBuffered(VisdomMighty): def __init__(self, legend_labels, env="main"): super().__init__(env=env) self.close() # clear previous plots self._legend_labels = tuple(legend_labels) legend_labels = self.legend_labels('k-active') self.opts = {'recall': dict( xlabel='Epoch', ylabel='overlap', legend=list(legend_labels), title='recall (y_pred, y_learned)', ), 'convergence': dict( xlabel='Epoch', ylabel='overlap', legend=list(legend_labels), title='convergence (y, y_prev)', ), 'support': dict( xlabel='Epoch', ylabel='support', legend=list(legend_labels), title='support size across epoch trials' ), 'similarity': dict( xlabel='Epoch', ylabel='similarity', legend=list(self._legend_labels), title="Learned assemblies pairwise similarity overlap", markers=True, markersize=8, )} self.data_epoch = defaultdict(list) def legend_labels(self, *prepend): return prepend + self._legend_labels def send_buffered(self): legend_labels = self.legend_labels('k-active') for win in self.data_epoch.keys(): data_epoch_finished = {name: np.nan for name in legend_labels} self.buffer(data=data_epoch_finished, win=win) n_trials = len(self.data_epoch[win]) y = np.full((n_trials, len(legend_labels)), fill_value=np.nan) times, data = zip(*self.data_epoch[win]) for data_dict, yi in zip(data, y): for label, val in data_dict.items(): yi[legend_labels.index(label)] = val y[:, 0] = K_ACTIVE times = np.tile(times, reps=(len(legend_labels), 1)).T self.line(Y=y, X=times, win=win, opts=self.opts[win], update='append') self.data_epoch.clear() def buffer(self, data, win): self.data_epoch[win].append((timer.epoch_progress(), data))
[docs]class Monitor: """ Monitor the training progress. Parameters ---------- model : AreaInterface A NN model, consisting of one or more areas. """ def __init__(self, model, env_suffix=''): self.model = model self.ys_output = dict() self.ys_previous = None self.ys_learned = defaultdict(list) self.support = {} self.handles = [] self.module_name = dict() for name, layer in find_named_layers(model, layer_class=AreaRNN): self.module_name[layer] = f"{name}-{layer.__class__.__name__}" \ .lstrip('-') handle = layer.register_forward_hook(self._forward_hook) self.handles.append(handle) env_name = f"{time.strftime('%Y.%m.%d')} " \ f"{model.__class__.__name__} assemblies" if env_suffix: env_name = f"{env_name} {env_suffix}" self.viz = VisdomBuffered(legend_labels=self.module_name.values(), env=env_name) self.log_expected_random_overlap() self.log_model() self.draw_model() self.reset() def _forward_hook(self, module, input, output): name = self.module_name[module] self.ys_output[name] = output
[docs] def remove_handles(self): """ Remove the hooks that has been used to track intermediate layers output. """ for handle in self.handles: handle.remove()
def reset(self): self.ys_output.clear() self.ys_previous = None self.ys_learned.clear() self.support.clear() def _names_active(self, ys_output=True): assert sorted(self.ys_previous.keys()) == sorted( self.module_name.values()) if ys_output: assert sorted(self.ys_previous.keys()) == sorted( self.ys_output.keys()) names_active = [] for name in self.module_name.values(): active = self.ys_previous[name] is not None if ys_output: active &= self.ys_output[name] is not None if active: names_active.append(name) return tuple(names_active) def _update_convergence(self): if self.ys_previous is None: return overlaps = {} for name in self._names_active(): overlaps[name] = self.ys_output[name].matmul( self.ys_previous[name]).item() self.viz.buffer(data=overlaps, win='convergence') def _update_support(self): support = {} for name, y in self.ys_output.items(): if y is None: continue if name not in self.support: self.support[name] = y.clone().bool() self.support[name] |= y.bool() support[name] = self.support[name].sum() self.viz.buffer(data=support, win='support') def _update_recall(self, x_samples_learned): assert len(self.ys_output) == 0 ys_learned = {} names_active = self._names_active(ys_output=False) for name in names_active: ys_learned[name] = self.ys_learned[name] + [self.ys_previous[name]] assert len(ys_learned[name]) == len(x_samples_learned) recall = defaultdict(float) for i, x in enumerate(x_samples_learned): # ys_output will be populated via the forward hook y_ignored = self.model.recall(x) for name in names_active: y_predicted = self.ys_output[name] y_learned = ys_learned[name][i] n_total = len(ys_learned[name]) recall[name] += (y_predicted * y_learned).sum() / n_total self.ys_output.clear() self.viz.buffer(data=recall, win='recall')
[docs] def update_memory_used(self): """ Plot the weights histograms. """ names, values = list(zip(*self.model.memory_used().items())) self.viz.line_update(y=values, opts=dict( xlabel='Epoch', title=r"Memory used (L0 norm)", legend=list(names), ))
[docs] def assembly_similarity(self): """ Computes the similarity of learned assemblies. Returns ------- similarity : dict A dict with pairwise assembly similarity. """ similarity = dict() for name, y_learned in self.ys_learned.items(): y_learned = [y for y in y_learned if y is not None] if len(y_learned) > 0: similarity[name] = pairwise_similarity(y_learned) return similarity
[docs] def trial_finished(self, x_samples_learned, step): """ A sample is being learned callback. Parameters ---------- x_samples_learned : list A list of learned input vectors to recall. Each entry can be either a single vector tensor (one incoming area) or a tuple of tensors (multiple incoming areas). step : int The batch ID. """ self._update_convergence() self._update_support() self.ys_previous = self.ys_output.copy() self.ys_output.clear() if step != 0: # skip the first batch since the recall is exactly 'k' self._update_recall(x_samples_learned)
[docs] def epoch_finished(self): """ A sample is learned callback. """ # ys_output is already cleared up for name, y_final in self.ys_previous.items(): self.ys_learned[name].append(y_final) self.ys_previous = None self.support.clear() self.update_memory_used() self.update_weight_histogram() self.viz.send_buffered()
[docs] def log_expected_random_overlap(self, n=N_NEURONS, k=K_ACTIVE): r""" Log the expected random overlap between samples, drawn from `Binomial(n, k)` distribution. Parameters ---------- n : int, optional The number of neurons in a layer. Default: 1000 k : int, optional The number of active neurons in a layer. Default: 50 """ self.viz.log(f"Expected random overlap (n={n}, k={k}): " f"{expected_random_overlap(n=n, k=k):.3f}")
[docs] def update_assembly_similarity(self, input_similarity=None, log=False): r""" Plot the :math:`L_{0/1}` similarity of the projected (learned) assemblies. The similarity of two binary vectors :math:`\bm{x}` and :math:`\bm{y}` of size `n` that have `k` active neurons is computed as their dot product, divided by `k`: .. math:: \frac{\bm{x} \cdot \bm{y}}{k} Parameters ---------- input_similarity : float or None, optional If given, plot the input vectors similarity as well. Default: None log : bool, optional If True, log the similarities as text. Default: False """ assembly_similarity = self.assembly_similarity() if input_similarity is not None: assembly_similarity['input'] = input_similarity legend = self.viz.opts['similarity']['legend'] similarity_nans = [assembly_similarity.get(name, np.nan) for name in legend] x = np.stack([[timer.epoch] * len(legend), similarity_nans], axis=1) self.viz.scatter(X=x, Y=range(1, len(legend) + 1), opts=self.viz.opts['similarity'], win='similarity', update='append') if log: lines = ["Learned assemblies intra-similarity:"] for name, similarity in assembly_similarity.items(): lines.append(f"--{name}: {similarity:.3f}") text = '<br>'.join(lines) self.viz.log(text=text)
[docs] def plot_associated_activations(self, ys_traces, ys_all_active, learned=False): """ Plot the output layer activations for each single trace (active input area). Overlapping neurons are shown in green. Parameters ---------- ys_traces : (S, P, N) torch.Tensor Stacked individual traces (output layer `C` activations). Axes: 1) `S` - the number of (learned) samples; 2) `P` - the number of incoming areas (parents); 3) `N` - the number of output neurons. ys_all_active : (S, N) torch.Tensor The output activations of layer `C` with all input areas active (have input). learned : bool, optional Have the model areas been associated (True) or not? This flag is used in the plots title only. Default: False """ learned_str = "after" if learned else "before" rgb_green = torch.tensor([0, 1, 0], dtype=torch.float) # (S, P+1, N) images = torch.cat([ys_traces, ys_all_active.unsqueeze(dim=1)], dim=1) row_size = images.shape[1] # (S, P+1, 3, N) images = images.unsqueeze(dim=2).repeat_interleave(3, dim=2) mask_overlap = ys_traces.prod(dim=1) # (S, N) ii_overlap, jj_overlap = mask_overlap.nonzero(as_tuple=True) images[ii_overlap, :-1, :, jj_overlap] = rgb_green mask_overlap *= ys_all_active ii_overlap_all, jj_overlap_all = mask_overlap.nonzero(as_tuple=True) images[ii_overlap_all, -1, :, jj_overlap_all] = rgb_green # (S * (P+1), 3, H, N/H) images = images.view(-1, 3, *factors_root(images.shape[-1])) resize = transforms.Resize(size=128, interpolation=InterpolationMode.NEAREST) images = resize(images) # (S * (P+1), 3, ~128, ~128) title = f"{learned_str}: A only | B only | all" self.viz.images(images, win=f"activations {learned_str}", nrow=row_size, opts=dict(title=title))
[docs] def log_model(self, space='-'): """ Log the :attr:`model`. Parameters ---------- space : str, optional A space substitution to correctly parse HTML later on. Default: '-' """ lines = ["Area model:"] for line in repr(self.model).splitlines(): n_spaces = len(line) - len(line.lstrip()) line = space * n_spaces + line lines.append(line) lines = '<br>'.join(lines) self.viz.log(lines)
[docs] @graphviz_notify_if_not_installed def draw_model(self, sample=None): """ Draw the model graph. Parameters ---------- sample : torch.Tensor or None, optional Input sample. Default: None """ graph = GraphArea() svg = graph.draw_model(self.model, sample=sample) self.viz.svg(svgstr=svg, win='graph')
[docs] def update_weight_histogram(self): """ Plot the model weights histogram. """ for name, param in self.model.named_parameters(): self.viz.histogram(X=param.data.view(-1), win=name, opts=dict( xlabel='Weight values', ylabel='count', title=f"{name} histogram", ytype='log', ))