Thank you so much for the explanation and the code, @astoecke!

Actually, I am trying to find out the gain that you mention here:
# biases; divide by the gain (can't figure out where to reset the gain internally)
sim.data[con_direct].weights.setflags(write=True)
sim.data[con_direct].weights[...] = W_reorg / sim.data[ens_post].gain[:, None]
remove_bias_current(sim, ens_post)
For that, we need the neuron properties, so:
- using PES supervised learning, I save the weights learned in model 1, double them, and run those doubled weights in a second model;
- to observe the response of the neurons, I am using tuning curves.
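On the gain question: it looks like the per-neuron gain and bias of a built ensemble can simply be read from sim.data; a minimal sketch, reusing the ens_post and sim names from the code below:

gains = sim.data[ens_post].gain    # shape (N_post,)
biases = sim.data[ens_post].bias   # shape (N_post,)

Is that the right way to get at them?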
I have been trying to add PES learning to the code you posted, and I have a few doubts:
- I do not understand how the error population and its connections work.
- I am confused by the error node; could you explain the error population in more detail? (I sketched my current understanding below this list.)
- In model 2, I am getting an error when I try to connect the pre-neurons to the post-population.
- I am also confused about what to pass as the function in this connection:
nengo.Connection(ens_pre, err, transform=-1, function=lambda x: x)
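My current understanding, from the PES examples in the Nengo documentation, is that the error population (or node) just computes the difference between what the post-population currently represents and the target value, and PES then adjusts the connection weights to drive this difference to zero. A minimal sketch of that pattern, reusing the names from my code below (so, if this is right, the function argument should be the target f rather than the identity; please correct me if I am wrong):

err = nengo.Node(size_in=1)                               # carries post - f(pre)
nengo.Connection(ens_post, err)                           # + actual value
nengo.Connection(ens_pre, err, function=f, transform=-1)  # - target value f(x)
nengo.Connection(err, con_direct.learning_rule)           # PES minimises this signal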
Here is the full code I am trying to run:
import nengo
import numpy as np
import scipy.optimize
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from nengo.dists import Choice
from nengo.utils.ensemble import response_curves, tuning_curves
N_pre = 101
N_post = 102
N_smpls = 10001
simtime = 20
f = lambda x: x * x # Compute the square
#f = lambda x: x # Compute a communication channel
xs = np.linspace(-1, 1, N_smpls).reshape(-1, 1)
def remove_bias_current(sim, ens):
    # Some fun fiddling around with Nengo-internal data-structures
    signals = sim.signals
    for signal in signals:
        if signal.name == str(ens) + ".bias":
            signal._initial_value.setflags(write=True)
            signal._initial_value[...] = 0.0

def compute_J(sim, ens, xs):
    data = sim.data[ens]
    return data.gain * (xs @ data.encoders.T) + data.bias

def compute_A(sim, ens, Js):  # "sim" is not needed, but makes the code look more symmetric
    return ens.neuron_type.rates(Js, gain=np.ones(ens.n_neurons), bias=np.zeros(ens.n_neurons))
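# (As I understand it, compute_J evaluates J = gain * <encoder, x> + bias for
# each neuron, and compute_A turns these input currents into firing rates via
# the ensemble's neuron model, with gain 1 and bias 0 so that the raw currents
# are used as-is.)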
with nengo.Network() as model:
    ens_pre = nengo.Ensemble(N_pre, 1, radius=2, max_rates=nengo.dists.Uniform(50, 100), label="pre")
    ens_post = nengo.Ensemble(N_post, 1, radius=2, max_rates=nengo.dists.Uniform(50, 100), label="post")
    nd_in = nengo.Node(lambda t: t / 10 - 1)
    nengo.Connection(nd_in, ens_pre)

    # Set up learning rule
    con_direct = nengo.Connection(ens_pre.neurons, ens_post.neurons, transform=np.zeros((N_post, N_pre)))
    con_direct.learning_rule_type = nengo.PES()
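    # PES will adapt these (initially all-zero) weights online so as to
    # minimise the error signal connected to con_direct.learning_rule below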
    # Create error population and connections
    err = nengo.Node(size_in=1)
    nengo.Connection(ens_post, err)
    nengo.Connection(ens_pre, err, transform=-1, function=lambda x: x)
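    # (This is the line from my question above: with function=lambda x: x the
    # target is x itself, so the error is post - pre; to learn f(x) = x**2 I
    # would expect function=f here.)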
    nengo.Connection(err, con_direct.learning_rule)

    # Add probes
    probe_in = nengo.Probe(nd_in)
    probe_pre = nengo.Probe(ens_pre, synapse=0.1)
    probe_post = nengo.Probe(ens_post, synapse=0.1)
    probe_err = nengo.Probe(err, synapse=0.1)

    # Add a probe for the full connection weight matrix (not just decoders,
    # since this is a neuron-to-neuron connection); with sample_every=simtime
    # only the weights at the very end of the run are recorded
    probe_weights = nengo.Probe(con_direct, "weights", sample_every=simtime)
with nengo.Simulator(model, optimize=False) as sim:
    eval_points, activities = tuning_curves(ens_pre, sim)
    eval_points_post, activities_post = tuning_curves(ens_post, sim)
    # Need to disable optimization to be able to access the biases

    # Compute the currents that we expect to be injected into the pre-population
    # when we represent x
    J_pre = compute_J(sim, ens_pre, xs)
    A_pre = compute_A(sim, ens_pre, J_pre)

    # Compute the current we need to inject into the post population when we
    # compute f(x)
    J_post = compute_J(sim, ens_post, f(xs))
    A_post = compute_A(sim, ens_post, J_post)  # Don't really need this, but useful for debugging

    # Mark some pre-neurons as excitatory, and others as inhibitory
    is_excitatory = np.random.choice([False, True], p=[0.3, 0.7], size=(N_pre,))
    is_inhibitory = ~is_excitatory

    # Use the lines below to deactivate Dale's principle
    #is_excitatory = np.ones(N_pre, dtype=bool)
    #is_inhibitory = np.ones(N_pre, dtype=bool)

    # Split the pre-activities into an excitatory and an inhibitory matrix
    A_pre_exc = A_pre[:, is_excitatory]
    A_pre_inh = A_pre[:, is_inhibitory]

    # Solve the NNLS problem (see p. 74 of http://hdl.handle.net/10012/17850)
    # for each post-neuron individually
    A = np.concatenate((A_pre_exc, -A_pre_inh), axis=1)
    N_pre_total = A.shape[1]  # Different from N_pre if neurons are marked both as excitatory and inhibitory
    sigma = 1.0  # Regularisation factor
    I_reg = N_smpls * sigma * sigma * np.eye(N_pre_total)
    W = np.zeros((N_post, N_pre_total))  # transposed compared to the thesis
    for i in range(N_post):
        W[i] = scipy.optimize.nnls(A.T @ A + I_reg, A.T @ J_post[:, i])[0]
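    # (As far as I can tell, nnls is applied here to the regularised normal
    # equations A.T @ A + I_reg and A.T @ J_post rather than to A directly,
    # which is how the regularisation factor sigma enters the problem.)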
    # Reconstruct a weight matrix of shape (N_post, N_pre); the weight matrix
    # computed above has all excitatory and inhibitory pre-neurons in
    # separate blocks; we need to sort them back to the order they are in
    # the nengo pre-population
    W_reorg = np.zeros((N_post, N_pre))
    i_exc, i_inh = 0, np.sum(1 * is_excitatory)
    for i in range(N_pre):
        if is_excitatory[i]:
            W_reorg[:, i] += W[:, i_exc]
            i_exc += 1
        if is_inhibitory[i]:
            W_reorg[:, i] -= W[:, i_inh]
            i_inh += 1

    # Forcefully shove these weights into the Nengo simulator; reset the
    # biases; divide by the gain (can't figure out where to reset the gain internally)
    sim.data[con_direct].weights.setflags(write=True)
    sim.data[con_direct].weights[...] = W_reorg / sim.data[ens_post].gain[:, None]
    remove_bias_current(sim, ens_post)

    # Run the simulation!
    sim.run(simtime)
# Save the learned weight matrix and double it for the second model
weights = sim.data[probe_weights][-1]
print(weights)
weights = weights * 2
print("doubled weights")
print(weights)
# Plot the results
ts = sim.trange()
xs_pre = sim.data[probe_pre]
xs_post = sim.data[probe_post]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].plot(ts, xs_pre)
axs[0].plot(ts, ts / 10.0 - 1.0, 'k--')
axs[0].set_xlabel("Time $t$")
axs[0].set_ylabel("Decoded value $x$")
axs[0].set_title("Pre-population")
axs[1].plot(ts, xs_post)
axs[1].plot(ts, f(ts / 10.0 - 1.0), 'k--')
axs[1].set_xlabel("Time $t$")
axs[1].set_ylabel("Decoded value $x$")
axs[1].set_title("Post-population")
# Extract and plot the excitatory and inhibitory weights
W_exc = W_reorg[:, is_excitatory]
W_inh = W_reorg[:, is_inhibitory]
print(W_exc)
print(W_inh)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(W_exc, vmin=-0.01, vmax=0.01, cmap='RdBu')
axs[0].set_xlabel("Exc. pre-neuron index")
axs[0].set_ylabel("Post-neuron index")
axs[0].set_title("Excitatory weights")
axs[1].imshow(W_inh, vmin=-0.01, vmax=0.01, cmap='RdBu')
axs[1].set_xlabel("Inh. pre-neuron index")
axs[1].set_ylabel("Post-neuron index")
axs[1].set_title("Inhibitory weights")
plt.figure()
plt.plot(eval_points, activities)
# We could have alternatively shortened this to
# plt.plot(*tuning_curves(ens_pre, sim))
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input scalar, x")
plt.figure()
plt.plot(eval_points_post, activities_post)
# We could have alternatively shortened this to
# plt.plot(*tuning_curves(ens_post, sim))
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input scalar, x")
# Model with loaded weights and no learning
with nengo.Network() as model2:
    ens_pre = nengo.Ensemble(N_pre, 1, max_rates=nengo.dists.Uniform(50, 100), label="pre")
    ens_post = nengo.Ensemble(N_post, 1, max_rates=nengo.dists.Uniform(50, 100), label="post")
    # (note: unlike model 1, these ensembles do not set radius=2)
    nd_in = nengo.Node(lambda t: t / 10 - 1)
    nengo.Connection(nd_in, ens_pre)

    # Load the doubled weights from model 1; this is the connection that
    # raises the error I mentioned above
    con_direct = nengo.Connection(ens_pre.neurons, ens_post.neurons, transform=weights)

    probe_in = nengo.Probe(nd_in)
    probe_pre = nengo.Probe(ens_pre, synapse=0.1)
    probe_post = nengo.Probe(ens_post, synapse=0.1)
with nengo.Simulator(model2, optimize=False) as sim2:
    eval_points, activities = tuning_curves(ens_pre, sim2)
    eval_points_post, activities_post = tuning_curves(ens_post, sim2)
    # Need to disable optimization to be able to access the biases

    # Compute the currents that we expect to be injected into the pre-population
    # when we represent x
    J_pre = compute_J(sim2, ens_pre, xs)
    A_pre = compute_A(sim2, ens_pre, J_pre)

    # Compute the current we need to inject into the post population when we
    # compute f(x)
    J_post = compute_J(sim2, ens_post, f(xs))
    A_post = compute_A(sim2, ens_post, J_post)  # Don't really need this, but useful for debugging

    # Mark some pre-neurons as excitatory, and others as inhibitory
    is_excitatory = np.random.choice([False, True], p=[0.3, 0.7], size=(N_pre,))
    is_inhibitory = ~is_excitatory

    # Use the lines below to deactivate Dale's principle
    #is_excitatory = np.ones(N_pre, dtype=bool)
    #is_inhibitory = np.ones(N_pre, dtype=bool)

    # Split the pre-activities into an excitatory and an inhibitory matrix
    A_pre_exc = A_pre[:, is_excitatory]
    A_pre_inh = A_pre[:, is_inhibitory]

    # Solve the NNLS problem (see p. 74 of http://hdl.handle.net/10012/17850)
    # for each post-neuron individually
    A = np.concatenate((A_pre_exc, -A_pre_inh), axis=1)
    N_pre_total = A.shape[1]  # Different from N_pre if neurons are marked both as excitatory and inhibitory
    sigma = 1.0  # Regularisation factor
    I_reg = N_smpls * sigma * sigma * np.eye(N_pre_total)
    W = np.zeros((N_post, N_pre_total))  # transposed compared to the thesis
    for i in range(N_post):
        W[i] = scipy.optimize.nnls(A.T @ A + I_reg, A.T @ J_post[:, i])[0]

    # Reconstruct a weight matrix of shape (N_post, N_pre); the weight matrix
    # computed above has all excitatory and inhibitory pre-neurons in
    # separate blocks; we need to sort them back to the order they are in
    # the nengo pre-population
    W_reorg = np.zeros((N_post, N_pre))
    i_exc, i_inh = 0, np.sum(1 * is_excitatory)
    for i in range(N_pre):
        if is_excitatory[i]:
            W_reorg[:, i] += W[:, i_exc]
            i_exc += 1
        if is_inhibitory[i]:
            W_reorg[:, i] -= W[:, i_inh]
            i_inh += 1

    # Forcefully shove these weights into the Nengo simulator; reset the
    # biases; divide by the gain (can't figure out where to reset the gain internally)
    sim2.data[con_direct].weights.setflags(write=True)
    sim2.data[con_direct].weights[...] = W_reorg / sim2.data[ens_post].gain[:, None]
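    # Note: these two lines overwrite the doubled weights that were loaded
    # via transform=weights above, so model 2 never actually uses them;
    # should this duplicated block be removed in model 2?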
    remove_bias_current(sim2, ens_post)

    # Run the simulation!
    sim2.run(simtime)
# Plot the results
ts = sim2.trange()
xs_pre = sim2.data[probe_pre]
xs_post = sim2.data[probe_post]
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].plot(ts, xs_pre)
axs[0].plot(ts, ts / 10.0 - 1.0, 'k--')
axs[0].set_xlabel("Time $t$")
axs[0].set_ylabel("Decoded value $x$")
axs[0].set_title("Pre-population")
axs[1].plot(ts, xs_post)
axs[1].plot(ts, f(ts / 10.0 - 1.0), 'k--')
axs[1].set_xlabel("Time $t$")
axs[1].set_ylabel("Decoded value $x$")
axs[1].set_title("Post-population")
# Extract and plot the excitatory and inhibitory weights
W_exc = W_reorg[:, is_excitatory]
W_inh = W_reorg[:, is_inhibitory]
print(W_exc)
print(W_inh)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(W_exc, vmin=-0.01, vmax=0.01, cmap='RdBu')
axs[0].set_xlabel("Exc. pre-neuron index")
axs[0].set_ylabel("Post-neuron index")
axs[0].set_title("Excitatory weights")
axs[1].imshow(W_inh, vmin=-0.01, vmax=0.01, cmap='RdBu')
axs[1].set_xlabel("Inh. pre-neuron index")
axs[1].set_ylabel("Post-neuron index")
axs[1].set_title("Inhibitory weights")
plt.figure()
plt.plot(eval_points, activities)
# We could have alternatively shortened this to
# plt.plot(*tuning_curves(ens_pre, sim2))
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input scalar, x")
plt.figure()
plt.plot(eval_points_post, activities_post)
# We could have alternatively shortened this to
# plt.plot(*tuning_curves(ens_post, sim2))
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input scalar, x")
I am just using your code and changing things. Please correct me if I am doing anything wrong.
Thanks