The problem is that you cannot swap in a new learning rule and rebuild. Other objects (e.g. the probes) still hold references to the old learning rule, and those references are not updated.
I think the easiest solution is to find (or make) a way to change the learning rate on the existing learning rule itself. Since this is a custom learning rule, you already have full control over how it is built. The simplest approach is to have the step function read `learning_rate` from the operator you create each step, rather than capturing it once when the step function is built. You can then change `learning_rate` on that operator between runs. Here's an example:
import numpy as np
import nengo
from nengo.builder.builder import Builder
from nengo.builder.learning_rules import build_or_passthrough, get_post_ens, get_pre_ens
from nengo.builder.operator import Operator
class MyOja(nengo.Oja):
    """Oja learning rule type that is built by a custom builder function.

    The class body adds nothing, so it inherits all parameters and behavior
    of ``nengo.Oja``. Its only purpose is to provide a distinct type for
    ``Builder.register``, so that the ``build_oja`` function registered for
    ``MyOja`` in this file is used to build it.
    """
class MySimOja(Operator):
    """Oja weight-update operator whose learning rate can be changed after build.

    The step function produced by ``make_step`` looks up
    ``self.learning_rate`` on every call instead of capturing it once.
    Assigning a new value to ``learning_rate`` on an instance therefore affects
    all subsequent simulation steps.

    Parameters
    ----------
    pre_filtered
        Signal holding filtered presynaptic activity (read).
    post_filtered
        Signal holding filtered postsynaptic activity (read).
    weights
        Signal holding the current connection weights (read).
    delta
        Signal receiving the computed weight change (update).
    learning_rate
        Scale of the update; multiplied by ``dt`` on every step.
    beta
        Coefficient of the forgetting term.
    tag
        Optional label passed through to ``Operator``.
    """

    def __init__(
        self, pre_filtered, post_filtered, weights, delta, learning_rate, beta, tag=None
    ):
        super().__init__(tag=tag)
        # Plain attributes: learning_rate may be reassigned mid-simulation.
        self.learning_rate = learning_rate
        self.beta = beta
        # Signal bookkeeping: no sets or incs; three reads; one update.
        self.sets, self.incs = [], []
        self.reads = [pre_filtered, post_filtered, weights]
        self.updates = [delta]

    @property
    def pre_filtered(self):
        """Filtered presynaptic activity signal (``reads[0]``)."""
        return self.reads[0]

    @property
    def post_filtered(self):
        """Filtered postsynaptic activity signal (``reads[1]``)."""
        return self.reads[1]

    @property
    def weights(self):
        """Connection weight signal (``reads[2]``)."""
        return self.reads[2]

    @property
    def delta(self):
        """Weight-change signal written by the step function (``updates[0]``)."""
        return self.updates[0]

    @property
    def _descstr(self):
        return f"pre={self.pre_filtered}, post={self.post_filtered} -> {self.delta}"

    def make_step(self, signals, dt, rng):
        """Return a closure that writes the Oja weight change into ``delta``."""
        w = signals[self.weights]
        pre = signals[self.pre_filtered]
        post = signals[self.post_filtered]
        dw = signals[self.delta]
        forget = self.beta

        def step_simoja():
            # Read the learning rate from the operator on every step (not once
            # at build time) so reassigning it later takes effect immediately.
            alpha = self.learning_rate * dt
            print(f"Alpha: {alpha}")
            # Forgetting term: each row of the weights (one per post element)
            # is scaled by -beta * alpha * post**2.
            scaled_post_sq = alpha * post * post
            dw[...] = -forget * w * scaled_post_sq[:, None]
            # Hebbian term: outer product of scaled post and pre activity.
            dw[...] += np.outer(alpha * post, pre)

        return step_simoja
@Builder.register(MyOja)
def build_oja(model, oja, rule):
    """Build a ``MyOja`` learning rule into ``model``.

    Filters the pre- and postsynaptic neuron outputs with the rule's synapses
    and adds a ``MySimOja`` operator that computes the rule's ``delta`` signal.
    It then stores the filtered signals in ``model.sig[rule]`` so the learning
    rule can be probed for ``"pre_filtered"`` and ``"post_filtered"``.

    Parameters
    ----------
    model
        Builder model; its ``sig`` mapping and ``add_op`` method are used.
    oja
        The ``MyOja`` rule type; supplies ``pre_synapse``, ``post_synapse``,
        ``learning_rate`` and ``beta``.
    rule
        The learning rule instance; ``rule.connection`` is the learned
        connection.
    """
    connection = rule.connection
    signals = model.sig

    pre_activities = signals[get_pre_ens(connection).neurons]["out"]
    post_activities = signals[get_post_ens(connection).neurons]["out"]
    pre_filtered = build_or_passthrough(model, oja.pre_synapse, pre_activities)
    post_filtered = build_or_passthrough(model, oja.post_synapse, post_activities)

    weights = signals[connection]["weights"]
    rule_signals = signals[rule]
    oja_op = MySimOja(
        pre_filtered,
        post_filtered,
        weights,
        rule_signals["delta"],
        learning_rate=oja.learning_rate,
        beta=oja.beta,
    )
    model.add_op(oja_op)

    # Expose the filtered activities so probes on the learning rule can read them.
    rule_signals["pre_filtered"] = pre_filtered
    rule_signals["post_filtered"] = post_filtered
# Demo: train for the first half of the run, then set the learning rate of the
# already-built operator to 0 for the second half without rebuilding.
seed = 0
time = 0.01  # total simulated time; the run is split into two halves
neurons = 10
train_neurons = [1] * neurons

model = nengo.Network(seed=seed)
with model:
    # Output +1 to every neuron during the first half, -2 afterwards.
    # (The node's input ``x`` is not used by the function.)
    inp = nengo.Node(
        lambda t, x: train_neurons if t < time / 2 else [-2] * neurons, size_in=1
    )
    ens = nengo.Ensemble(neurons, 1)
    nengo.Connection(inp, ens.neurons, seed=seed)
    # Recurrent neuron-to-neuron connection learned with the custom Oja rule,
    # starting from all-zero weights.
    conn = nengo.Connection(
        ens.neurons,
        ens.neurons,
        learning_rule_type=MyOja(),
        transform=np.zeros((ens.n_neurons, ens.n_neurons)),
    )
    ens_probe = nengo.Probe(ens.neurons)
    weight_probe = nengo.Probe(conn, "weights")
    pre_filt_probe = nengo.Probe(conn.learning_rule, "pre_filtered")
    post_filt_probe = nengo.Probe(conn.learning_rule, "post_filtered")

with nengo.Simulator(model, progress_bar=False) as sim:
    sim.run(time / 2)

    # Find the MySimOja operator built for ``conn``. Matching on its weights
    # signal keeps this correct even if the model has several MyOja rules.
    weights_sig = sim.model.sig[conn]["weights"]
    ops = [
        op
        for op in sim.model.operators
        if isinstance(op, MySimOja) and op.weights == weights_sig
    ]
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would hide a missing or ambiguous match.
    if len(ops) != 1:
        raise RuntimeError(
            f"Expected exactly one MySimOja operator for {conn}, found {len(ops)}"
        )
    (op,) = ops

    # The step function reads ``op.learning_rate`` on every step, so this
    # disables learning for the rest of the simulation.
    op.learning_rate = 0
    sim.run(time / 2)