Call custom function on Object at every timestep

I am trying to define a custom Node object that should simulate a memristor, but I’m having trouble understanding how to implement a custom learning rule along with it.

Basically, I have a Memristor() class that implements two methods: filter() and pulse(). The former function filters the input based on the current resistance state of the object and this is implemented in my Nengo model by:

memristor = nengo.Node( synapse.filter, size_in=1, size_out=1 )
nengo.Connection( A, memristor )
B = nengo.Ensemble( 100, dimensions=1 )
nengo.Connection( memristor, B )

I would now like to modify the Memristor() resistance state at runtime based on some simple learning rule that I would implement in Memristor.pulse(); for example, “IF error > 0 THEN decrease the resistance state”.
How would I go about implementing this in Nengo? I currently tried something like this:

learn_conn = nengo.Connection(err, memristor)
learn_conn.learning_rule_type = synapse.pulse()

but the pulse() function is called only once, not at every timestep as I would have expected.

Am I thinking about this in the wrong way? I am conscious that I would ideally be integrating my memristor model into a new Synapse() type but I wanted to start with a simpler proof-of-concept before delving into Nengo’s backend.

Thanks in advance!

Hi @Tioz, welcome to the forum! It’s exciting to hear that you’re looking into simulating memristors in Nengo.

I think the easiest way to get your setup working would be to have the pulse method also called in a Node. I’m assuming that synapse is an instance of Memristor. Since the two nodes are calling methods on the same instance, they can both modify its internal state.

The only potential issue with doing it with a node is that the order in which the two nodes are run within the simulation is not defined unless they depend on one another. So, if you want to make sure that the pulse method is called after the filter method, you would have to connect the memristor node to the new node you created to run pulse. Your pulse method can ignore the value that it gets from that connection, though.
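
Here is a minimal sketch of this idea that does not use Nengo. It assumes a toy Memristor with one scalar resistance `R` and a hypothetical rule, “if error > 0, decrease R”. A hand-written loop stands in for the simulator, and none of these names come from Nengo. Both methods take the `(t, x)` arguments that a Node passes to its function. Because both are bound to the same instance, a change made in `pulse` is visible to the next `filter` call.

```python
class Memristor:
    """Toy memristor: a single resistance shared by filter() and pulse()."""

    def __init__(self, R=0.5, step=0.1, r_min=0.0):
        self.R = R          # current resistance state (arbitrary units)
        self.step = step    # hypothetical change applied per pulse
        self.r_min = r_min  # lower bound on the resistance
        self.calls = []     # log of which method ran, for illustration

    def filter(self, t, x):
        # Node function: scale the input by the current resistance state
        self.calls.append(("filter", t))
        return self.R * x

    def pulse(self, t, x):
        # Node function: x[0] is the filter output, received only to enforce
        # ordering and otherwise ignored; x[1] is the error signal
        self.calls.append(("pulse", t))
        err = x[1]
        if err > 0:  # IF error > 0 THEN decrease the resistance state
            self.R = max(self.r_min, self.R - self.step)
        return self.R


synapse = Memristor()
dt = 0.001
for step in range(1, 4):
    t = step * dt
    y = synapse.filter(t, 1.0)   # the "memristor" Node runs first...
    synapse.pulse(t, [y, 1.0])   # ...then the "pulse" Node, with error = 1.0

print(round(synapse.R, 3))  # 0.2
```

In a real model, `synapse.filter` and `synapse.pulse` would be passed to two separate `nengo.Node`s. A connection from the filter node to the pulse node then guarantees the order that the loop above imitates.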

Hopefully that makes sense! If not, I can make a small example of how that might look, or you can post your script here and I can take a closer look. And yes, as you say, in the end I would expect the memristor to be a Synapse type. Similarly, a learning rule has to be a LearningRuleType instance, so you’ll have to learn both of those parts of the backend. If you get to that point and have any questions, we can definitely answer those here on the forum too.

Hi @tbekolay, thanks for your quick response!

I have tried implementing a simple learning system using your suggestions but I feel that I am working with wrong basic assumptions.
Mainly, I am confused about the nature of the objects I am manipulating with my “memristor” Node().

The computation graph of my simple experiment contains two memristors because I am trying to model the connection using an “excitatory” synapse (“Memristor+”) and an “inhibitory” synapse (“Memristor-”).

My main problem is that I’m confused about the inputs that the two memristors are working with; I imagine that these are the encoded vectors projecting from “Pre”. As I am trying to replace a single synapse with my setup, I imagine I should be using synapse=None on the connections in/out of my memristor nodes.
If so, does my current setup make sense? Or should I be suppressing the decoders/encoders on the connection too? Is that even feasible at the level of abstraction of Nengo? Should I be working with the raw weights instead?

import nengo

# exc_synapse and inh_synapse are Memristor instances (defined elsewhere)
model = nengo.Network()

with model:
    inp = nengo.Node( output=lambda t: int( 6 * t / 5 ) / 3.0 % 2 - 1, size_out=1, label="Input" )
    pre = nengo.Ensemble( 1, dimensions=1, label="Pre" )
    post = nengo.Ensemble( 1, dimensions=1, label="Post" )
    err = nengo.Ensemble( 100, dimensions=1, label="Error" )
    memristor_plus = nengo.Node( exc_synapse.filter, size_in=1, size_out=1, label="Memristor+" )
    memristor_minus = nengo.Node( inh_synapse.filter, size_in=1, size_out=1, label="Memristor-" )
    pulse_plus = nengo.Node( exc_synapse.pulse, size_in=1, size_out=1, label="Pulse+" )
    pulse_minus = nengo.Node( inh_synapse.pulse, size_in=1, size_out=1, label="Pulse-" )

    nengo.Connection( inp, pre )
    nengo.Connection( pre, memristor_plus, synapse=None )
    nengo.Connection( pre, memristor_minus, synapse=None  )
    nengo.Connection( memristor_plus, post, synapse=None  )
    nengo.Connection( memristor_minus, post, synapse=None,  transform=-1 )
    nengo.Connection( pre, err, function=lambda x: x, transform=-1 )
    nengo.Connection( post, err )
    nengo.Connection( err, pulse_plus )
    nengo.Connection( err, pulse_minus )
    nengo.Connection( memristor_plus, pulse_plus )
    nengo.Connection( memristor_minus, pulse_minus )

Thanks again!

Thanks for the context! You’re right that, with the way you have things set up, your connections from pre carry vectors rather than neural activities (specifically, the values decoded from pre). Am I correct in assuming that you would like to use neural activities and raw connection weight matrices rather than decoded values and vectors, for everything except the connections to/from err? If so, here is one way to set things up.

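with model:
    inp = nengo.Node(
        output=lambda t: int(6 * t / 5) / 3.0 % 2 - 1, size_out=1, label="Input"
    )
    pre = nengo.Ensemble(1, dimensions=1, label="Pre")
    post = nengo.Ensemble(1, dimensions=1, label="Post")
    err = nengo.Ensemble(100, dimensions=1, label="Error")
    memristor_plus = nengo.Node(
        exc_synapse.filter,
        size_in=pre.n_neurons,
        size_out=post.n_neurons,
        label="Memristor+",
    )
    memristor_minus = nengo.Node(
        inh_synapse.filter,
        size_in=pre.n_neurons,
        size_out=post.n_neurons,
        label="Memristor-",
    )
    pulse_plus = nengo.Node(
        exc_synapse.pulse,
        size_in=post.n_neurons + err.dimensions,
        size_out=post.n_neurons,
        label="Pulse+",
    )
    pulse_minus = nengo.Node(
        inh_synapse.pulse,
        size_in=post.n_neurons + err.dimensions,
        size_out=post.n_neurons,
        label="Pulse-",
    )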

    nengo.Connection(inp, pre)
    nengo.Connection(pre.neurons, memristor_plus, synapse=None)
    nengo.Connection(pre.neurons, memristor_minus, synapse=None)
    nengo.Connection(memristor_plus, post.neurons, synapse=None)
    nengo.Connection(memristor_minus, post.neurons, synapse=None, transform=-1)
    nengo.Connection(pre, err, function=lambda x: x, transform=-1)
    nengo.Connection(post, err)
    nengo.Connection(memristor_plus, pulse_plus[: post.n_neurons], synapse=None)
    nengo.Connection(memristor_minus, pulse_minus[: post.n_neurons], synapse=None)
    nengo.Connection(err, pulse_plus[post.n_neurons :])
    nengo.Connection(err, pulse_minus[post.n_neurons :])

The main things that I’ve changed are:

  • The connections to/from pre and post use pre.neurons and post.neurons, which give you the underlying neural activities rather than the decoded values. This allows you to modify the number of neurons in the pre and post ensembles while still being able to use the decoded pre and post vectors to compute the error.

  • Since you can vary the number of neurons in pre and post now, I used pre.n_neurons and post.n_neurons for size_in and size_out.

  • The connections to pulse_plus and pulse_minus now send the filtered neural activities from memristor_plus and memristor_minus as the first dimension(s), and the error signal from err as the last dimension. This allows you to separate out these two inputs: if you provide both to the same dimension, they will be summed together. By putting them in separate dimensions, you can deal with them internally in Memristor like so:

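        def pulse(self, t, x):
            filtered_activities = x[: post.n_neurons]  # from memristor_plus/minus
            err = x[post.n_neurons :]  # from the err ensemble
            # ... update the resistance state based on err here ...
            return filtered_activities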
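
This small NumPy sketch shows the packed vector that `pulse` receives. It assumes one post neuron (as in the model above) and a 1-D error, and the values are made up. Slicing recovers both signals. If both signals were fed into the same dimension, only their sum would remain.

```python
import numpy as np

n_post = 1  # stands in for post.n_neurons (1 neuron, as in the model)
activities = np.array([1000.0])  # hypothetical filtered activity from memristor_plus
error = np.array([-0.25])        # hypothetical decoded error from err

# Separate dimensions: the two sliced Connections effectively build this vector
x = np.concatenate([activities, error])
filtered_activities = x[:n_post]
err = x[n_post:]

# Same dimension: Nengo sums all inputs arriving at a dimension, so pulse
# would only see the sum and could not recover either signal
summed = activities + error
```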

I verified that the above model works with a mocked out Memristor class:

class Memristor:
    def filter(self, t, x):
        return x
    def pulse(self, t, x):
        return x[: post.n_neurons]

exc_synapse = Memristor()
inh_synapse = Memristor()

One thing to note when you run this model is that your input function will not change for a relatively long time. The default timestep in Nengo is 0.001 (1 millisecond), so your input won’t change for the first ~0.83 simulated seconds, or around 830 timesteps. You can always raise the dt when you create the simulator, e.g. nengo.Simulator(model, dt=0.5), if you’re expecting a different timestep size.
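
You can check the timing by evaluating the input function directly. This sketch assumes that the simulator evaluates the function at t = step * dt for step = 1, 2, .... It finds the first step where the output changes and lists the distinct levels over one 5-second period.

```python
def input_fn(t):
    # Same staircase input as the "Input" Node
    return int(6 * t / 5) / 3.0 % 2 - 1


dt = 0.001  # Nengo's default timestep (1 ms)
initial = input_fn(dt)
first_change = next(step for step in range(1, 10_000) if input_fn(step * dt) != initial)
print(first_change)  # 834, i.e. about 0.83 s

# Distinct values over one 5 s period (6 steps of 5/6 s each)
levels = sorted({round(input_fn(step * dt), 4) for step in range(1, 5001)})
print(levels)
```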

Also, apologies for changing the code formatting, my editor auto-formats Python code and I forgot to turn that off!


Thanks again for all your input!

First, I’d like to report a bug I’ve found in nengo_extras.graphviz.net_diagram(): with the updated code you posted, the output graph is not drawn correctly, but nengo_gui renders it correctly.

Returning to my problem: with this setup, is the input to memristor_plus and memristor_minus the weights associated with each neuron in the pre ensemble? Or is it some sort of “neural activity” (and what does this term mean in this context, exactly)?
I have checked the output of the pre = nengo.Ensemble(100, dimensions=1, label="Pre") population and it is a (100,) vector. What does each element represent, exactly?

I’m asking this because I’m trying to understand how to actually “filter” using the memristor function:

# filter the input
def filter( self, t, x ):
    # record the resistance state on every timestep
    self.history.append( self.R )
    # scale weights for filtering: maps r0 to 0 and r1 to 1 linearly
    w = (self.R - self.r0) / (self.r1 - self.r0)
    return w * x

Until I’m sure what x actually is, I’ll have some trouble conceptualising if what I’m doing is correct or not…
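
One way to make the method self-contained is to keep the resistance bounds and the state on the instance. This sketch assumes that `r0` and `r1` are the resistances mapped to weights 0 and 1, and it picks arbitrary default values. It also adds a hypothetical `pulse` rule (“IF error > 0 THEN decrease the resistance state”) that reads the error from the last input dimension. None of these values come from a real device model.

```python
import numpy as np


class Memristor:
    """Toy memristor sketch; the r0/r1 bounds and pulse rule are assumptions."""

    def __init__(self, r0=100.0, r1=1000.0, R=550.0, dR=50.0):
        self.r0 = r0       # assumed resistance mapped to weight 0
        self.r1 = r1       # assumed resistance mapped to weight 1
        self.R = R         # current resistance state
        self.dR = dR       # hypothetical resistance change per pulse
        self.history = []  # resistance recorded on every filter() call

    def filter(self, t, x):
        # x: neural activities from pre.neurons on this timestep
        self.history.append(self.R)
        w = (self.R - self.r0) / (self.r1 - self.r0)  # r0 -> 0, r1 -> 1
        return w * np.asarray(x)

    def pulse(self, t, x):
        # Hypothetical rule: IF error > 0 THEN decrease the resistance state
        err = x[-1]
        if err > 0:
            self.R = max(self.r0, self.R - self.dR)
        return self.R


m = Memristor()
first = m.filter(0.001, [0.0, 1000.0])   # weight (550 - 100) / 900 = 0.5
m.pulse(0.001, [0.0, 0.3])               # positive error: R drops by dR
second = m.filter(0.002, [0.0, 1000.0])  # smaller weight, smaller output
```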

Thanks again!

Thanks for reporting the bug in net_diagram! We’re currently deciding what parts of Nengo Extras to maintain better and what parts to remove, so it’s great to know that you’ve been finding the graphviz module useful (despite the bugs). If we decide to keep it, we’ll be sure to fix it before the Nengo Extras release.

The input is the neural activity, not the weights. By “neural activity,” I am referring to the output of the neuron’s activation function, which is a function of its input current. In your network, the input current is determined by the inp node. The activation function depends on the neuron_type of the ensemble, which in your case is nengo.LIF.

The (100,) vector that you get inside the filter method is what each neuron is doing on that timestep (timestep t). Since you’re using the default LIF neuron type and you’re not filtering the nengo.Connection from pre to your memristor nodes, the neural activity is always going to be either 0 or 1 / dt: 0 when the neuron doesn’t spike, and 1 / dt when it does. We use 1 / dt for spikes so that you can filter the neural activity of a spiking neuron to obtain an estimate of its firing rate.
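
This short NumPy sketch shows what such a vector looks like over time for a single neuron. It uses a made-up spike pattern rather than a real LIF simulation. Each spike has height 1 / dt and lasts one timestep, so its area is 1. As a result, summing the activity times dt counts the spikes, and the average activity is the firing rate in Hz.

```python
import numpy as np

dt = 0.001
# Hypothetical spike pattern for one neuron over 100 timesteps (0.1 s):
# one spike every 10 steps, i.e. 10 spikes in total
spiked = np.zeros(100, dtype=bool)
spiked[9::10] = True

# What the memristor node would receive from pre.neurons with synapse=None
activity = np.where(spiked, 1.0 / dt, 0.0)

print(np.unique(activity))   # only 0 and 1/dt appear
print(activity.sum() * dt)   # area under the trace = number of spikes
print(activity.mean())       # average value = firing rate in Hz
```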

In your network, there are no weights on the connection from pre.neurons to memristor_plus / memristor_minus; more accurately, the weights are an identity matrix. Changing the weights during the simulation would require implementing a learning rule type, but with the way things are set up, you can instead store and modify the weights inside Memristor, which seems to be what you are doing.

For a more visual representation of what’s happening, see the NEF summary notebook. Cell [7] shows the neural activity of 8 neurons over time. We draw lines for spikes, but you can imagine that inside the filter method, each of those lines corresponds to a filter call in which the value for that neuron is 1 / dt for that timestep. Cell [9] shows what happens if you add a lowpass filter to the neural activity; I presume the output of your filter method will be similar.
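
The effect of a lowpass synapse can be sketched with a first-order exponential filter applied to a regular spike train. The update rule below is one common discretization, assumed here for illustration; Nengo's own `Lowpass` implementation may differ in its details. With tau = 10 ms (as with synapse=0.01), the filtered trace settles into fluctuations around the firing rate.

```python
import numpy as np

dt = 0.001
tau = 0.01    # 10 ms time constant, like synapse=0.01
rate = 100.0  # hypothetical firing rate in Hz

n_steps = 2000
period = int(round(1.0 / (rate * dt)))  # 10 timesteps between spikes
spikes = np.zeros(n_steps)
spikes[period - 1::period] = 1.0 / dt   # spikes have height 1 / dt

# First-order lowpass: y[n] = a * y[n-1] + (1 - a) * x[n]
a = np.exp(-dt / tau)
filtered = np.zeros(n_steps)
y = 0.0
for n, x in enumerate(spikes):
    y = a * y + (1 - a) * x
    filtered[n] = y

# After the initial transient, the trace fluctuates around the firing rate
steady = filtered[1000:]
print(steady.mean())
```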

I am finding the graphviz output useful, as I’m coding in an external IDE and it’s great to be able to quickly visualise the network topology. So I hope the function is maintained!

So, if I’ve understood properly, every pre -> post connection can be seen as an algorithm going:

  1. pre population takes decoded vector and encodes it into spikes
  2. take spike data from each neuron in pre population
  3. filter the spikes using some temporal filter, i.e., a synapse
  4. multiply the filtered spikes by a weight (determined by the NEF or at runtime)
  5. input this vector to the post population, which sums the components to give the original signal
  6. go to step 1.
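
The steps above can be sketched as a single NumPy loop. This is a simplified stand-in for what Nengo does, and it rests on several assumptions:

- Euler-integrated LIF neurons without a refractory period
- made-up gains and biases
- 1-D encoders
- a random weight matrix `W` standing in for weights solved by the NEF or set by memristors

```python
import numpy as np

rng = np.random.default_rng(0)
dt, tau_rc, tau_syn = 0.001, 0.02, 0.01
n_pre, n_post = 5, 3

encoders = rng.choice([-1.0, 1.0], size=n_pre)    # 1-D encoders for pre
gain = rng.uniform(1.0, 3.0, size=n_pre)          # hypothetical gains
bias = rng.uniform(0.5, 1.5, size=n_pre)          # hypothetical biases
W = rng.normal(scale=1e-3, size=(n_post, n_pre))  # stand-in weight matrix

v = np.zeros(n_pre)    # membrane voltages (neuron state)
syn = np.zeros(n_pre)  # synapse state (filtered spikes)
spike_count = np.zeros(n_pre)
decay = np.exp(-dt / tau_syn)

for step in range(1, 1001):
    x = np.sin(2 * np.pi * step * dt)         # decoded value to represent
    J = gain * encoders * x + bias            # 1. encode into input currents
    v += dt / tau_rc * (J - v)                # 2. LIF membrane update (Euler)
    spiked = v > 1.0
    v[spiked] = 0.0                           #    reset after a spike
    spikes = spiked / dt                      #    spikes have height 1 / dt
    spike_count += spiked
    syn = decay * syn + (1 - decay) * spikes  # 3. lowpass synapse
    post_input = W @ syn                      # 4-5. weighted sum into post
```

The sketch stops at the input to the post neurons. Recovering a signal from post's activity would be another weighted sum (a decoding), not a plain sum.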

So, basically, at this level I’m replacing the NEF weight calculation with my memristors, but I want to keep everything else the same. If I don’t want to implement my own temporal filter just yet, could I simply keep the connections like this:
nengo.Connection( pre.neurons, memristor_plus, synapse=0.01 )
nengo.Connection( memristor_plus, post.neurons, synapse=None )
Is that correct?
In that case, would this be a fair representation of what I would be doing, and would the result be equivalent to the default setup using the NEF or some other built-in learning algorithm?

My remaining doubt is about connecting multiple Memristor() instances to a post Ensemble() of neurons instead of to a single neuron. Is the reconstructed vector computed by summing the inputs of all the neurons in the post population together?

Yes, that’s a pretty accurate assessment of what most connections are doing. One important thing to keep in mind, though, is that this whole process happens continuously over time: we don’t pre-encode the spikes for a certain input at the start of the simulation; they are generated at each moment in time based on the current decoded vector and the state of the neurons. That’s an important distinction because, unlike in other types of neural networks, the output activity of the pre population is determined not solely by the input (i.e., the decoded vector) but also by internal state (e.g., the current membrane voltage of the LIF neuron).
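
The role of internal state shows up even in a single neuron. This sketch uses an Euler-integrated LIF neuron without a refractory period, which is a simplification and not Nengo's `LIF` implementation. It drives the neuron with the same constant current twice and changes only the initial membrane voltage.

```python
def lif_spike_times(v0, J=1.5, dt=0.001, tau_rc=0.02, n_steps=100):
    """Simple Euler LIF (no refractory period); returns spike step indices."""
    v, times = v0, []
    for step in range(n_steps):
        v += dt / tau_rc * (J - v)  # leaky integration toward J
        if v > 1.0:                 # threshold crossed: spike and reset
            times.append(step)
            v = 0.0
    return times


# Identical input current, different initial membrane voltage
resting = lif_spike_times(v0=0.0)
charged = lif_spike_times(v0=0.9)
print(resting[:3], charged[:3])  # same spacing, different spike times
```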

Yes, that’s right: the input activities will then be spike trains filtered with a 10 ms lowpass filter.

Yes, that’s what would happen by default. Importantly, though, because you are connecting from pre.neurons, there is no NEF decoding or built-in learning algorithm involved: the connection gives you the output activities without any initial weights (just an identity matrix). You can have the pre -> post connection do an NEF decoding in weights by using a weight solver:

with model:
    # weight solvers need an ensemble as the post object
    nengo.Connection(pre, post, solver=nengo.solvers.LstsqL2(weights=True))
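
The following NumPy sketch shows the idea behind a weight solver. It is a simplified illustration, not Nengo's code. The tuning curves are rectified-linear rather than LIF, and all parameters are made up. The decoders come from an L2-regularized least-squares solve in the spirit of LstsqL2. The full weight matrix is then the post encoders (scaled by the post gains) times the decoders. Some decoders come out negative, which is why more activity does not automatically mean a larger decoded value.

```python
import numpy as np

rng = np.random.default_rng(1)
n_pre, n_post, n_eval = 50, 40, 200
x = np.linspace(-1, 1, n_eval)[:, None]  # evaluation points, shape (200, 1)


def tuning(n, rng):
    """Hypothetical rectified-linear tuning curve parameters."""
    enc = rng.choice([-1.0, 1.0], size=n)
    intercepts = rng.uniform(-0.9, 0.9, size=n)
    max_rates = rng.uniform(100, 200, size=n)
    gain = max_rates / (1 - intercepts)
    bias = -gain * intercepts
    return enc, gain, bias


enc_pre, gain_pre, bias_pre = tuning(n_pre, rng)
A = np.maximum(0, gain_pre * (x * enc_pre) + bias_pre)  # (n_eval, n_pre) rates

# L2-regularized least squares for the decoders
sigma = 0.1 * A.max()
G = A.T @ A + n_eval * sigma**2 * np.eye(n_pre)
D = np.linalg.solve(G, A.T @ x)  # (n_pre, 1)

# Full weight matrix: post encoders scaled by post gains, times decoders
enc_post, gain_post, _ = tuning(n_post, rng)
W = (gain_post * enc_post)[:, None] @ D.T  # (n_post, n_pre)

decoded = A @ D
print(W.shape, np.sqrt(np.mean((decoded - x) ** 2)))
```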

Yes, the reconstructed vector is a weighted sum of the activities of all the neurons in the post population. The weighting is important: more neural activity does not necessarily mean a higher decoded output. Also, you will see better performance if your Memristor class is set up to emulate a population of memristors rather than creating several instances of Memristor.
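
Here is a minimal sketch of the population approach. It assumes:

- one resistance per (post neuron, pre neuron) pair, stored in a single NumPy array
- the same linear resistance-to-weight mapping as the filter method above
- a hypothetical pulse rule

A single Node could then call `filter` once per timestep (with `size_in=pre.n_neurons` and `size_out=post.n_neurons`). The whole weight matrix is applied in one matrix product instead of one Python call per memristor. The loop at the end only shows that both approaches give the same result.

```python
import numpy as np


class MemristorPopulation:
    """Hypothetical vectorized memristor model: one resistance per synapse."""

    def __init__(self, n_pre, n_post, r0=100.0, r1=1000.0, dR=10.0, seed=0):
        rng = np.random.default_rng(seed)
        self.r0, self.r1, self.dR = r0, r1, dR
        # resistance matrix, shape (n_post, n_pre), random initial states
        self.R = rng.uniform(r0, r1, size=(n_post, n_pre))

    def weights(self):
        # Same linear mapping as filter(): r0 -> 0, r1 -> 1
        return (self.R - self.r0) / (self.r1 - self.r0)

    def filter(self, t, x):
        # x: (n_pre,) activities -> (n_post,) weighted sums in one product
        return self.weights() @ x

    def pulse(self, t, x):
        # Hypothetical rule applied to every synapse at once:
        # IF error > 0 THEN decrease all resistance states
        if x[-1] > 0:
            self.R = np.maximum(self.r0, self.R - self.dR)


pop = MemristorPopulation(n_pre=4, n_post=3)
x = np.array([0.0, 1000.0, 0.0, 1000.0])
y_vec = pop.filter(0.001, x)

# Equivalent result computed one synapse at a time (slow for large populations)
W = pop.weights()
y_loop = np.array([sum(W[i, j] * x[j] for j in range(4)) for i in range(3)])
```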

I see from other posts that you’re making progress with your model, but hopefully this response is still helpful :slight_smile:


Absolutely still relevant and I hope useful for other people too :slight_smile:
