So you can either do that with a learning rule, or with the hacky way I outlined. Here are examples of each.
Note that the PES example below uses the learning rule differently from how we normally use PES. We typically feed an error signal into the PES learning rule (i.e. the actual value minus the target value) so that the rule adjusts the weights to reduce that error. In the setup here, the modulator node is connected straight into the learning rule, so the modulator's value directly sets the direction of the weight change. Because PES steps each weight against the sign of its input (it performs gradient descent on the error), with the decoded connection used here a positive modulator value pushes the weights down and a negative value pushes them up. Also note that PES scales each weight change by the activity of the corresponding presynaptic neuron (and, on neuron-to-neuron connections, additionally multiplies the error by the postsynaptic encoders), so the weights on individual neurons are adjusted by different amounts.
import matplotlib.pyplot as plt
import numpy as np
import nengo
from nengo.processes import WhiteSignal
# Model parameters shared by both examples below:
# neurons per ensemble, represented dimensions, and simulation length (seconds).
n_neurons, d, tsim = 100, 1, 10
# --- adjust weights with PES learning rule
# The PES rule's error input is driven directly by a modulator node rather than
# by (actual - target), so the modulator's sign and size set the direction and
# speed of the weight change.
with nengo.Network(seed=0) as net:
    u = nengo.Node(WhiteSignal(period=tsim, high=0.5))
    # 0 for t <= 5 s (zero error input, so PES makes no change), then -1.
    modulator = nengo.Node(lambda t: -1 if t > 5 else 0)
    a = nengo.Ensemble(n_neurons, d)
    b = nengo.Ensemble(n_neurons, d)
    nengo.Connection(u, a, synapse=None)
    c = nengo.Connection(a, b, learning_rule_type=nengo.PES(learning_rate=1e-4))
    # The modulator takes the place of the usual error signal.
    nengo.Connection(modulator, c.learning_rule)

    up = nengo.Probe(u, synapse=0.03)
    bp = nengo.Probe(b, synapse=0.03)
    weight_p = nengo.Probe(c, "weights")
    # Weight change applied by PES on each step; inspect sim.data[delta_p]
    # (e.g. np.abs(...).sum(axis=-1)) if you want to see it directly.
    delta_p = nengo.Probe(c.learning_rule, "delta")

with nengo.Simulator(net, seed=1) as sim:
    sim.run(tsim)

t = sim.trange()
plt.figure()
plt.subplot(211)
plt.title("PES learning rule driven by a modulator")
plt.plot(t, sim.data[up], label="input u")
plt.plot(t, sim.data[bp], label="output b")
plt.legend(loc="upper right")
plt.subplot(212)
plt.plot(t, np.abs(sim.data[weight_p]).sum(axis=-1))
plt.xlabel("time [s]")
plt.ylabel("sum |weights|")
# --- adjust weights with hacky node
# The node callback below needs the simulator and the connection, which only
# exist after the network and simulator have been built. They are handed over
# through these one-element lists, which are filled in later in the script.
sim_reference = [None]
conn_reference = [None]


def modulator_fn(t):
    """Add ``dt * 0.01`` to every weight of ``conn_reference[0]`` once ``t > 5``.

    Intended as the callback of an output-less (``size_out=0``) node. For
    ``t <= 5`` it does nothing. Afterwards it reads the connection's weights
    signal from ``sim_reference[0].signals``, adds ``dt * 0.01`` to it, and
    writes the result back through ``signals``.

    Args:
        t: Current simulation time in seconds.

    Raises:
        RuntimeError: If called with ``t > 5`` before ``sim_reference[0]`` and
            ``conn_reference[0]`` have been set.
    """
    if t <= 5:
        return
    # Distinct local names so the script-level ``sim`` is not shadowed.
    simulator = sim_reference[0]
    connection = conn_reference[0]
    if simulator is None or connection is None:
        raise RuntimeError(
            "modulator_fn needs sim_reference[0] and conn_reference[0] to be set"
        )
    weights_key = simulator.model.sig[connection]["weights"]
    simulator.signals[weights_key] = (
        simulator.signals[weights_key] + simulator.dt * 0.01
    )
with nengo.Network(seed=0) as net:
    u = nengo.Node(WhiteSignal(period=tsim, high=0.5))
    # size_out=0: this node produces no output; it exists only so that
    # modulator_fn is invoked while the simulation runs.
    modulator = nengo.Node(modulator_fn, size_out=0)
    a = nengo.Ensemble(n_neurons, d)
    b = nengo.Ensemble(n_neurons, d)
    nengo.Connection(u, a, synapse=None)
    # Nothing is connected to c.learning_rule here, so PES should see a zero
    # error and make no updates of its own; weight changes come from
    # modulator_fn. NOTE(review): the rule appears to be attached only so the
    # connection's weights are built as a modifiable signal -- confirm before
    # removing it.
    c = nengo.Connection(a, b, learning_rule_type=nengo.PES(learning_rate=1e-4))
    conn_reference[0] = c

    up = nengo.Probe(u, synapse=0.03)
    bp = nengo.Probe(b, synapse=0.03)
    weight_p = nengo.Probe(c, "weights")

with nengo.Simulator(net, seed=1) as sim:
    # Must be set before running: modulator_fn looks the simulator up here.
    sim_reference[0] = sim
    sim.run(tsim)

t = sim.trange()
plt.figure()
plt.subplot(211)
plt.title("Weights modified directly by a node")
plt.plot(t, sim.data[up], label="input u")
plt.plot(t, sim.data[bp], label="output b")
plt.legend(loc="upper right")
plt.subplot(212)
plt.plot(t, np.abs(sim.data[weight_p]).sum(axis=-1))
plt.xlabel("time [s]")
plt.ylabel("sum |weights|")
plt.show()