Hi there,
I’m looking to implement a new learning rule with spiking neurons in Nengo. The idea is that, between weight updates, the network clamps its inputs and outputs and lets the hidden neuron/ensemble activities “relax” towards the desired (prospective) activations. Below is a simple version of the network written in plain NumPy:
```python
import numpy as np


# Assumed definitions (standard ReLU and its derivative)
def relu(x):
    return np.maximum(x, 0.0)


def relu_derivative(x):
    return (x > 0).astype(float)


def initialize_weights_pcn():
    # weights[l] maps layer l to layer l + 1 (shape: n_{l+1} x n_l)
    return [np.ones((1, 1)), np.ones((2, 1))]


def initialize_x_pcn():
    # Input, hidden and output activities as column vectors
    return [np.ones((1, 1)), np.zeros((1, 1)), np.array([[0.0], [1.0]])]


def train_predictive_coding(x, weights, T=120, epochs=250, learning_rate=1, gamma=0.1):
    weights_history = []
    for epoch in range(epochs):
        # Clamp the input and output layers (column vectors, like the other layers)
        x[0] = np.ones((1, 1))
        x[-1] = np.array([[0.0], [1.0]])
        # Store a copy of the current weights
        weights_history.append([w.copy() for w in weights])
        # Relaxation: this is where I need the weights at runtime (read, not modified)
        for t in range(T):
            # epsilon[l] = x[l] - W[l-1] relu(x[l-1]); epsilon[0] is unused
            epsilon = [None] + [x[l + 1] - weights[l] @ relu(x[l]) for l in range(len(x) - 1)]
            for l in range(1, len(x) - 1):  # skip the clamped first and last layers
                Delta_x = gamma * (-epsilon[l] + relu_derivative(x[l]) * (weights[l].T @ epsilon[l + 1]))
                x[l] += Delta_x
        # Weight update (PES-like), using the errors from the last relaxation step;
        # np.outer already has the same shape as weights[l], so no transpose is needed
        for l in range(len(x) - 1):
            weights[l] += learning_rate * np.outer(epsilon[l + 1], relu(x[l]))
    return weights_history
```
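For reference, the loop above implements the following updates, writing $f$ for `relu`, $W_l$ for `weights[l]`, $\gamma$ for `gamma` and $\eta$ for `learning_rate`:

$$
\begin{aligned}
\varepsilon_{l+1} &= x_{l+1} - W_l\, f(x_l), \\
\Delta x_l &= \gamma\left(-\varepsilon_l + f'(x_l) \odot W_l^{\top} \varepsilon_{l+1}\right) \quad \text{(hidden layers only)}, \\
\Delta W_l &= \eta\, \varepsilon_{l+1}\, f(x_l)^{\top}.
\end{aligned}
$$

The $\Delta x_l$ step is applied `T` times per epoch and the $\Delta W_l$ step once per epoch. $\Delta W_l$ has the PES-like form (post-synaptic error times pre-synaptic activity), while $\Delta x_l$ additionally needs $W_l^{\top}$, which is why the activity update depends on the weights.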
To do this in Nengo, I understand I need to create my own learning rule type; as the weight update above shows, it is similar to the PES rule (post-synaptic error times pre-synaptic activity). One last detail: I want the weight and activity updates to run in the same simulation, with a weight update roughly once every 120 activity updates (`T=120` above). I have attached my attempt so far, but I’m running into the following issues:
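One way to run both updates in the same loop is to keep the weights and activities in a plain Python object whose `__call__(t, u)` performs one relaxation step per simulator time step and a PES-like weight update every `T` calls. This is a minimal sketch, not a Nengo API: the class name `PCNStep`, passing the clamped input and target as the call's input, and recomputing the errors just before the weight update (a small deviation from the loop above) are all assumptions. It is also rate-based (ReLU on the state), not spiking. The `(t, x)` signature is what `nengo.Node` expects of a callable output when `size_in > 0`.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def relu_derivative(x):
    return (x > 0).astype(float)


class PCNStep:
    """Hypothetical Node-style callable: one relaxation step per call,
    plus a PES-like weight update every `T` calls."""

    def __init__(self, sizes, T=120, learning_rate=1.0, gamma=0.1):
        self.sizes = sizes  # e.g. [1, 1, 2] for input, hidden, output
        self.T = T
        self.learning_rate = learning_rate
        self.gamma = gamma
        # weights[l] maps layer l to layer l + 1
        self.weights = [np.ones((sizes[l + 1], sizes[l])) for l in range(len(sizes) - 1)]
        self.x = [np.zeros((n, 1)) for n in sizes]
        self.n_calls = 0

    def errors(self):
        # epsilon[l] = x[l] - W[l-1] f(x[l-1]); epsilon[0] is unused
        return [None] + [
            self.x[l + 1] - self.weights[l] @ relu(self.x[l])
            for l in range(len(self.x) - 1)
        ]

    def __call__(self, t, u):
        # u holds the clamped input followed by the clamped target
        n_in = self.sizes[0]
        self.x[0] = np.asarray(u[:n_in], dtype=float).reshape(-1, 1)
        self.x[-1] = np.asarray(u[n_in:], dtype=float).reshape(-1, 1)

        # One relaxation step on the hidden layers (weights are only read here)
        eps = self.errors()
        for l in range(1, len(self.x) - 1):
            self.x[l] += self.gamma * (
                -eps[l] + relu_derivative(self.x[l]) * (self.weights[l].T @ eps[l + 1])
            )

        # Every T calls, apply the weight update with freshly computed errors
        self.n_calls += 1
        if self.n_calls % self.T == 0:
            eps = self.errors()
            for l in range(len(self.weights)):
                self.weights[l] += self.learning_rate * np.outer(eps[l + 1], relu(self.x[l]))

        # Output the hidden-layer activities
        return np.concatenate([xl.ravel() for xl in self.x[1:-1]])
```

In a Nengo model this might be used as `nengo.Node(PCNStep([1, 1, 2]), size_in=3, size_out=1)`, fed by a node that outputs the clamped input and target. With the simulator's default `dt` of 0.001 s, `T=120` corresponds to one weight update every 0.12 s of simulated time.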
- To compute the updates, I need the error between the post-synaptic activity and the prediction from the pre-synaptic layer of each connection, as well as that connection’s weights.
- I tried updating the activities in the operator class, but it seems to allow only a single update, which is used for `delta` (the weight change).
- I have also considered doing everything at runtime with Nodes instead of inside the operator; however, from a Node I can’t access the weights needed to compute the error ε or Δx.
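Regarding the last point, a possible workaround is to keep the weights in a plain Python object rather than in a Nengo `Connection`. Because Node callables can close over shared state, one function can read the weights for the relaxation step while another writes them every `T` steps. This is a sketch under that assumption: `SharedWeights`, `make_relax_fn` and `make_learn_fn` are hypothetical names, the code hard-codes a single hidden layer, and it is rate-based rather than spiking.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def relu_derivative(x):
    return (x > 0).astype(float)


class SharedWeights:
    """Hypothetical holder: read by the relaxation fn, written by the learning fn."""

    def __init__(self, weights):
        # weights[0]: input -> hidden, weights[1]: hidden -> output
        self.weights = [np.array(w, dtype=float) for w in weights]


def make_relax_fn(shared, gamma=0.1):
    n_in = shared.weights[0].shape[1]
    hidden = np.zeros((shared.weights[0].shape[0], 1))  # hidden state, kept between calls

    def relax(t, u):
        # u = [clamped input (n_in values), clamped target]
        u = np.asarray(u, dtype=float)
        x_in, x_out = u[:n_in].reshape(-1, 1), u[n_in:].reshape(-1, 1)
        W0, W1 = shared.weights  # current weights, read at runtime
        eps1 = hidden - W0 @ relu(x_in)
        eps2 = x_out - W1 @ relu(hidden)
        hidden[:] = hidden + gamma * (-eps1 + relu_derivative(hidden) * (W1.T @ eps2))
        return hidden.ravel().copy()

    return relax


def make_learn_fn(shared, T=120, learning_rate=1.0):
    n_hid, n_in = shared.weights[0].shape
    n_calls = [0]  # mutable counter shared with the closure

    def learn(t, u):
        # u = [clamped input, hidden activities, clamped target]
        u = np.asarray(u, dtype=float)
        x = [u[:n_in].reshape(-1, 1),
             u[n_in:n_in + n_hid].reshape(-1, 1),
             u[n_in + n_hid:].reshape(-1, 1)]
        n_calls[0] += 1
        if n_calls[0] % T == 0:
            for l, W in enumerate(shared.weights):
                eps = x[l + 1] - W @ relu(x[l])
                W += learning_rate * np.outer(eps, relu(x[l]))  # in place, so relax sees it
        return None  # the learning function produces no output

    return learn
```

Each returned function could be the output of its own `nengo.Node`, with the hidden activities connected from the relaxation node to the learning node (e.g. with `synapse=None`). Because the weights are modified in place, the relaxation node uses the new values on its next call; Nengo itself never applies these weights, they exist only inside the node functions.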
PCN_simple_network.py (4.6 KB)