import torch
from functools import partial
from tqdm import tqdm
from assembly.areas import *
from assembly.constants import K_ACTIVE, N_NEURONS
from assembly.monitor import Monitor, pairwise_similarity
from assembly.samplers import *
from mighty.monitor.batch_timer import timer
class Simulator:
    """
    Train and simulate Computation with Assemblies.

    Parameters
    ----------
    model : AreaInterface
        A NN model, consisting of one or more areas.
    epoch_size : int, optional
        Defines the number of simulations to run for each trial sample.
        Each trial sample represents a complete epoch.
        Default: 10
    env_suffix : str, optional
        Passed through unchanged to :class:`Monitor`.
        Default: ''
    """

    def __init__(self, model, epoch_size=10, env_suffix=''):
        self.model = model
        self.monitor = Monitor(model, env_suffix=env_suffix)
        self.epoch_size = epoch_size
        # The shared batch timer is initialised with the number of steps
        # that make up one epoch.
        timer.init(epoch_size)

    def simulate(self, x_samples):
        """
        Train and simulate the :attr:`model` with the input `x_samples` data.

        Parameters
        ----------
        x_samples : list
            The input stimuli samples list. Each entry can be either a vector
            tensor or, in case of :class:`AreaStack`, a pair of tensors.
        """
        self.monitor.reset()
        # The model type cannot change while iterating, so test it once.
        is_sequential = isinstance(self.model, AreaSequential)
        progress = tqdm(x_samples, desc="Projecting")
        for sample_count, x in enumerate(progress, start=1):
            y_prev = None  # inhibit the area
            for step in range(self.epoch_size):
                output = self.model(x, y_latent=y_prev)
                if is_sequential:
                    # AreaSequential returns a pair; only its second item is
                    # fed back as `y_latent` on the next step.
                    _, y_prev = output
                else:
                    y_prev = output
                timer.tick()
                # The monitor is given every sample seen so far, including
                # the current one.
                self.monitor.trial_finished(x_samples[:sample_count],
                                            step=step)
            # NOTE(review): the original indentation was lost in extraction.
            # Placing the next two calls once per sample (i.e. per epoch)
            # rather than once per step is a reconstruction - confirm.
            self.model.normalize_weights()
            self.monitor.epoch_finished()
        if len(x_samples) > 1:
            self.monitor.update_assembly_similarity(
                input_similarity=pairwise_similarity(x_samples))

    def associate_benchmark(self, x_samples, learned=False):
        """
        Measure `associate` operation overlap between projected assemblies
        from two (or more) parent areas. Each assembly is projected
        individually.

        The model is switched to evaluation mode for the duration of the
        benchmark and its previous training mode is restored afterwards,
        even when the benchmark raises.

        Parameters
        ----------
        x_samples : list of tuple of torch.Tensor
            The input stimuli samples list. In this case, each entry must be
            a pair of vectors.
        learned : bool, optional
            Have the model areas been associated (True) or not? This flag
            selects the "before"/"after" wording of the log message, is
            forwarded to the activations plot, and enables the similarity
            scatter point when True.
            Default: False

        Raises
        ------
        AssertionError
            If the model is not an :class:`AreaSequential` or an entry of
            `x_samples` is not a list or a tuple.
        """
        assert isinstance(self.model, AreaSequential)
        mode_saved = self.model.training
        self.model.eval()
        try:
            n_parents = len(x_samples[0])
            ys_traces = []  # individual traces
            for parent_active in range(n_parents):
                ys = []
                for x_pair in x_samples:
                    assert isinstance(x_pair, (list, tuple))
                    # Present only one parent's stimulus; the remaining
                    # inputs are None.
                    x_active = [None] * len(x_pair)
                    x_active[parent_active] = x_pair[parent_active]
                    y, _ = self.model(x_active)
                    ys.append(y)
                    # NOTE(review): the original indentation was lost in
                    # extraction. Resetting after every forward pass rather
                    # than once per parent is a reconstruction - confirm.
                    self.monitor.reset()  # we don't need the history
                ys = torch.stack(ys)  # (n_samples, n_neurons)
                ys_traces.append(ys)
            ys_traces = torch.stack(ys_traces, dim=1)  # (S, P, N)
            ys_all_active = torch.stack([self.model(x_pair)[0]
                                         for x_pair in x_samples])
            self.monitor.plot_associated_activations(
                ys_traces=ys_traces,
                ys_all_active=ys_all_active,
                learned=learned)

            # Dot products between every pair of parent traces per sample.
            pairwise = ys_traces.bmm(ys_traces.transpose(1, 2))  # (S, P, P)
            # Only the strict upper triangle: distinct parent pairs.
            ii, jj = torch.triu_indices(row=n_parents, col=n_parents,
                                        offset=1)
            similarity = pairwise[:, ii, jj].mean()
            similarity /= K_ACTIVE
            learned_str = "after" if learned else "before"
            self.monitor.viz.log(f"Assemblies inter-similarity {learned_str} "
                                 f"learning: {similarity:.3f}")
            if learned:
                self.monitor.viz.scatter(
                    X=[[timer.epoch + 1, similarity]], Y=[1],
                    win='similarity', name='A-B via C',
                    opts=self.monitor.viz.opts['similarity'],
                    update='append')
        finally:
            # Restore the caller's mode on both the success and error paths.
            self.model.train(mode_saved)
def associate_example(n_samples=5, area_type=AreaRNNHebb):
    """
    Demonstrate the `associate` operation on a three-area model.

    Areas A and B are stacked and feed area C. The model is first trained
    on A-only and B-only stimuli, benchmarked, then trained on the paired
    stimuli and benchmarked again.

    Parameters
    ----------
    n_samples : int, optional
        The number of stimuli drawn for each of the two inputs.
        Default: 5
    area_type : callable, optional
        The area constructor. It is called with the input sizes as
        positional arguments and `out_features`, `p_synapse`, `update` and
        `learning_rate` as keyword arguments.
        Default: AreaRNNHebb
    """
    # Stimulus sizes followed by the sizes of areas A, B and C.
    n_stim_a, n_stim_b = N_NEURONS, N_NEURONS // 2
    na, nb, nc = N_NEURONS * 2, int(N_NEURONS * 1.5), N_NEURONS

    # All three areas share the same synapse and learning settings.
    make_area = partial(area_type, p_synapse=0.05, update='multiplicative',
                        learning_rate=0.1)
    first_parent = make_area(n_stim_a, out_features=na)
    second_parent = make_area(n_stim_b, out_features=nb)
    child = make_area(na, nb, out_features=nc)
    parents = AreaStack(first_parent, second_parent)
    brain = AreaSequential(parents, child)
    print(brain)

    # Stimuli for A are drawn before the stimuli for B.
    xa_samples = []
    for _ in range(n_samples):
        xa_samples.append(sample_k_active(n=n_stim_a, k=K_ACTIVE))
    xb_samples = []
    for _ in range(n_samples):
        xb_samples.append(sample_k_active(n=n_stim_b, k=K_ACTIVE))

    only_a = [(xa, None) for xa in xa_samples]
    only_b = [(None, xb) for xb in xb_samples]
    x_pairs = list(zip(xa_samples, xb_samples))

    simulator = Simulator(model=brain, epoch_size=10)
    # Train on each parent's stimuli separately.
    simulator.simulate(x_samples=only_a)
    simulator.simulate(x_samples=only_b)
    simulator.associate_benchmark(x_samples=x_pairs, learned=False)
    # Train on both parents' stimuli together, then measure again.
    simulator.simulate(x_samples=x_pairs)
    simulator.associate_benchmark(x_samples=x_pairs, learned=True)
def simulate_example(n_samples=10):
    """
    Run :meth:`Simulator.simulate` on a single `AreaRNNHebb` area.

    Parameters
    ----------
    n_samples : int, optional
        The number of stimuli drawn with `sample_k_active`.
        Default: 10
    """
    # The area is built before the stimuli are drawn.
    area = AreaRNNHebb(N_NEURONS, out_features=N_NEURONS // 2)
    stimuli = []
    for _ in range(n_samples):
        stimuli.append(sample_k_active(n=N_NEURONS, k=K_ACTIVE))
    simulator = Simulator(model=area, epoch_size=10)
    simulator.simulate(x_samples=stimuli)
if __name__ == "__main__":
    # Seed torch's RNG before running the demo.
    torch.manual_seed(19)
    associate_example()
    # simulate_example()