Manually calculate error between Ensembles from their outputs

I have the following setup, where a learn node computes a weight matrix in order to learn to approximate a function between the pre and post ensembles:

learn = nengo.Node( memr_arr,
                    size_in=pre_nrn + pre.dimensions,
                    size_out=post_nrn,
                    label="Learn" )

nengo.Connection( inp, pre )

nengo.Connection( pre.neurons, learn[ :pre_nrn ], synapse=0.01 )
nengo.Connection( pre, learn[ pre_nrn: ], synapse=0.01 )
nengo.Connection( learn, post.neurons, synapse=None )

Instead of having an external ensemble calculate the error and project it back to learn, I would like to calculate the error directly inside the learn node, so that I don't have to worry about network delays.

How would I go about doing this? Inside learn I could have:

def __call__( self, t, x ):
    input_activities = x[ :self.input_size ]
    ground_truth = self.function_to_learn( x[ self.input_size: ] )
    
    # query each memristor for its conductance (used as the weight)
    extract_R = lambda memristor: memristor.get_state( t, value="conductance", scaled=True )
    extract_R_V = np.vectorize( extract_R )
    weights = extract_R_V( self.memristors )
    # calculate the output at this timestep
    return_value = np.dot( weights, input_activities )
    
    # calculate error
    error = return_value - ground_truth
    self.error_history.append( error )
    
    return return_value

Does that seem correct, or do I still have to deal with some delay in the error signal?
Also, return_value is just the weights applied to the neuron activities (i.e., the input currents to post), not the value post represents, so I can't simply say error = return_value - ground_truth. How would I actually compute the value that post would represent in order to calculate the error?

Am I going about this all wrong? I also thought of projecting back from post to learn, but that again would introduce a delay that wouldn't be easily quantifiable.
Another idea would be to introduce a buffer in my learn node to accumulate error signals and only apply a delayed update once the error signal for a given timestep has arrived.

How does the PES rule deal with the delay in the error when modulating a connection?

A certain amount of delay is inherent in any online learning rule, by definition, because you’re using the input/output at timestep t to calculate an error signal, which will cause a change in weights, which will affect the output, resulting in a new error signal, and so on. So at some point in that loop there is going to be a one timestep delay. But you can choose where that delay is.

Note that using an external node or ensemble to compute the error doesn’t automatically mean there will be network delays. E.g., you could have

    ens0 = nengo.Ensemble(10, 1)
    ens1 = nengo.Ensemble(10, 1)
    learn_conn = nengo.Connection(
        ens0, ens1, learning_rule_type=nengo.PES(), synapse=None
    )
    error = nengo.Node(size_in=1)
    nengo.Connection(ens0, error, synapse=None, transform=-1)
    nengo.Connection(ens1, error, synapse=None, transform=1)
    nengo.Connection(error, learn_conn.learning_rule, synapse=None)

This network has no network delays (ensured by setting synapse=None throughout). So on timestep t, the error node will be comparing the output of ens0 and ens1 from time t, computing an error signal (also at time t), and computing a change in weights (also at time t). However, the weights won’t actually change until timestep t+1. So that is where the one timestep delay comes into that system.
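
If you want to check that timing for yourself, one option (just a sketch, added to the network above) is to probe the error signal and the connection weights and compare when each of them changes:

    # probe the error signal and the learned weights; the weight change
    # computed from the error at time t only appears on the next timestep
    p_error = nengo.Probe(error, synapse=None)
    p_weights = nengo.Probe(learn_conn, "weights", synapse=None)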

Alternatively, you could implement it so that the error node is comparing the output of ens0 and ens1 at time t-1, and then computing an error signal at time t and updating the weights at time t as well, so that it immediately affects the output of ens1 at time t. In this case you are adding the one timestep delay on the input to the error calculation. I believe this is how you have things set up in your code above.

Or you could add the delay on the output of the error signal (so that the learning rule looks at the output of error from t-1, and then immediately applies a weight update).
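
As a concrete sketch of that last option (keeping the rest of the example above unchanged), you could put the delay on the connection into the learning rule:

    # synapse=0 delays the error signal by exactly one timestep,
    # so the learning rule sees the error computed on the previous step
    nengo.Connection(error, learn_conn.learning_rule, synapse=0)

(With the built-in PES the resulting weight change still only takes effect on the following timestep, as described above; if you want the update to apply on the same step as the error, you would need to implement that inside your own node.)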

But what you cannot do is do all of that on the same timestep, because your error signal depends on the output of ens1, which in turn depends on the weights between ens0 and ens1. So if you tried to set up a system where everything is computed and applied on the same timestep, you would have a circular dependency (the change in weights depends on the output of ens1, but the output of ens1 depends on the change in weights). That’s just a basic feature of online learning rules, not particular to Nengo.

So what you need to do is figure out where in your learning rule loop you want that delay to be, and implement things appropriately. Depending on where you want that delay to be, you may be able to just use normal Nengo connections to implement your error calculation, or you may need to implement it internally in the Node.


Does the delay depend on the synapse parameter, i.e., does synapse=0.1 introduce a 1ms delay on that connection? Or is it just 1 timestep, independent of the filter used?

Yes, I think I’ll try to go with this. What would be the correct way to take the decoded signal from ens1 and compare it with that of ens0? Is it as simple as doing a subtraction between them, or would these be two arrays?

synapse=None is no delay, and synapse=0 introduces an exact one-timestep delay. Synapses > 0 spread the information across multiple future timesteps (by applying a lowpass filter). So it’s not exactly a fixed delay: you will have more information at timestep t+1 and less information at t+2, t+3, etc., proportional to the magnitude of the filter.
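
To see the difference concretely, here is a small toy sketch (not from your model) that probes the same step input through the three kinds of synapse:

    import nengo

    with nengo.Network() as model:
        # a step input that turns on after 10 ms
        stim = nengo.Node(lambda t: 1.0 if t > 0.01 else 0.0)
        p_none = nengo.Probe(stim, synapse=None)   # no delay
        p_step = nengo.Probe(stim, synapse=0)      # exact one-timestep delay
        p_low = nengo.Probe(stim, synapse=0.01)    # lowpass filter with tau = 10 ms

    with nengo.Simulator(model) as sim:
        sim.run(0.05)

    # sim.data[p_none] follows the input on the same timestep, sim.data[p_step]
    # lags it by one dt, and sim.data[p_low] rises gradually toward the input.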

That depends on your error function (i.e., what comparison you want to make between ens0 and ens1). Subtraction is one common error function, but not the only one. There’s no “correct” error function, it just depends on the mathematics of your learning rule.
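
For the subtraction case specifically, one way to write it out explicitly (a sketch, assuming the same ens0/ens1/learn_conn as in the example above and one-dimensional ensembles) is to do the comparison inside the error node:

    # the decoded values of ens0 and ens1 arrive concatenated in x;
    # the node outputs their difference as the error signal
    error = nengo.Node(lambda t, x: x[1] - x[0], size_in=2, size_out=1)
    nengo.Connection(ens0, error[0], synapse=None)
    nengo.Connection(ens1, error[1], synapse=None)
    nengo.Connection(error, learn_conn.learning_rule, synapse=None)

If the ensembles are multidimensional, both decoded signals are vectors of length equal to the ensemble dimensionality, and the subtraction is simply elementwise, so the error has the same dimensionality as the post ensemble.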