Spiking LMUs are still an area of active research, and we don’t have any tried-and-true methods of training them.
One method that has seen success in the past is using a “hybrid spiking” approach, where we start with something that is similar to a non-spiking (rate) neuron, and gradually transition to a spiking neuron model. This is described in this paper.
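To make the idea of a gradual transition more concrete, here is a minimal NumPy sketch of one way to give a neuron model a "rate-to-spiking" knob. This is an illustrative assumption, not necessarily the exact scheme used in the paper. It simulates integrate-and-fire rectified-linear neurons whose spikes are scaled by a hypothetical `resolution` factor. With `resolution=1` they emit ordinary unit spikes, similar in spirit to `nengo.SpikingRectifiedLinear`. Larger values let them emit more, smaller spikes per timestep, so the per-timestep output approaches that of a rate neuron like `nengo.RectifiedLinear`. `hybrid_relu` and `resolution` are hypothetical names.

```python
import numpy as np

def hybrid_relu(rates, n_steps, dt=0.001, resolution=1.0):
    """Simulate integrate-and-fire ReLU neurons with spikes scaled by 1/resolution.

    rates: target firing rates (Hz), one per neuron (negative values are rectified).
    Returns an array of shape (n_steps, n_neurons) with the output at each step.
    """
    rates = np.maximum(np.asarray(rates, dtype=float), 0)
    voltage = np.zeros_like(rates)
    outputs = np.zeros((n_steps,) + rates.shape)
    for t in range(n_steps):
        # integrate the input; every whole unit of voltage becomes one (scaled) spike
        voltage += rates * dt * resolution
        n_spikes = np.floor(voltage)
        voltage -= n_spikes
        # each spike has area 1/resolution, so the output is spikes / (dt * resolution)
        outputs[t] = n_spikes / (dt * resolution)
    return outputs

rates = np.array([10.0, 55.0, 120.0])
for res in (1, 10, 100):
    out = hybrid_relu(rates, n_steps=1000, resolution=res)
    step_err = np.abs(out - rates).mean()
    print(f"resolution={res:>3}: mean output {out.mean(axis=0).round(2)}, "
          f"mean per-step error {step_err:.2f}")
```

In every case the time-averaged output matches the target rate. However, the per-timestep output is much noisier at `resolution=1`. A training schedule along these lines would start with a large `resolution` and lower it towards 1. How quickly to do that is a design choice this sketch does not address.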
While it’s possible that a bug in your code is preventing correct training, it is also possible that you need a training method more like this “hybrid spiking” approach. To distinguish between these two cases, I would first try replacing all spiking neurons in your model with rate neurons (e.g. using `nengo.RectifiedLinear` instead of the corresponding spiking neuron type).
If that still doesn’t train, then I would go back a step further and try something in between the original example, which uses only nodes, and your example, which uses only ensembles. For example, you might make just the `m` node an ensemble with `nengo.RectifiedLinear` neurons, and see if that works.
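Turning `m` into an ensemble means its value is no longer passed through exactly. Instead, it is encoded into the activities of the ensemble's neurons and decoded back out with linear decoders, which introduces some approximation error. The NumPy sketch below illustrates that encode/decode step for rate ReLU neurons. It is only a simplified stand-in for what Nengo does when it builds an ensemble: Nengo chooses encoders, gains and biases differently, and its default solver, `nengo.solvers.LstsqL2`, adds regularization. The sizes `d` and `n` and all parameter ranges are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

d = 2    # dimensionality of the represented vector (hypothetical; small for illustration)
n = 200  # number of neurons in the ensemble (hypothetical)

# random unit-length encoders, plus gains and biases for rectified-linear neurons
encoders = rng.normal(size=(n, d))
encoders /= np.linalg.norm(encoders, axis=1, keepdims=True)
gains = rng.uniform(0.5, 2.0, size=n)
biases = rng.uniform(-1.0, 1.0, size=n)

def activities(x):
    # rate (non-spiking) ReLU response of every neuron to each row of x
    return np.maximum(gains * (x @ encoders.T) + biases, 0)

# solve for linear decoders mapping activities back to x (plain least squares here)
x_train = rng.uniform(-1, 1, size=(1000, d))
decoders, *_ = np.linalg.lstsq(activities(x_train), x_train, rcond=None)

# check how well new points are reconstructed after the encode/decode round trip
x_test = rng.uniform(-1, 1, size=(200, d))
x_hat = activities(x_test) @ decoders
rmse = np.sqrt(np.mean((x_hat - x_test) ** 2))
print(f"decoding RMSE: {rmse:.4f}")
```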
Once you get things working with rate neurons, there are a couple of considerations when transitioning to spiking, in addition to using the “hybrid spiking” approach referenced above. The first is that you may need synapses on some of the connections, to help smooth out the spikes and reduce the amount of “noise”. You’ll notice that in the original example, none of the connections apply any synaptic filtering (e.g. `synapse=None` means no filtering). If you do add in these synapses, you’ll also want to change the A and B matrices to account for them; the details of how to do this are in Section 3.3 of the “hybrid spiking” paper.
Adding in synapses will also change how quickly the network can respond to changes in the input. To account for this, you may want to show each element of the input for more than one timestep, which you can do by repeating each element for a number of timesteps. For example, if your input is an array `images` whose first axis is the batch axis, second is the time axis and third is the dimension, `images = images.repeat(5, axis=1)` repeats each element 5 times along the time axis.
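Here is a quick NumPy check of what `repeat` does on a small, made-up array with the (batch, time, dimension) layout described above:

```python
import numpy as np

# made-up batch: 2 sequences, 3 timesteps each, 1 dimension
images = np.arange(6).reshape((2, 3, 1))

# repeat each timestep 5 times along the time axis (axis=1)
repeated = images.repeat(5, axis=1)
print(repeated.shape)      # (2, 15, 1)
print(repeated[0, :, 0])   # [0 0 0 0 0 1 1 1 1 1 2 2 2 2 2]

# for contrast, tiling repeats the whole sequence rather than each element
tiled = np.tile(images, (1, 5, 1))
print(tiled[0, :, 0])      # [0 1 2 0 1 2 0 1 2 0 1 2 0 1 2]
```

`repeat` keeps the copies of each timestep adjacent, which is what is wanted here. `np.tile` would instead replay the whole sequence. The resulting sequences are 5 times longer, so anything in the model that is measured in timesteps may need to be scaled to match.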
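For the A and B adjustment mentioned above, Section 3.3 of the “hybrid spiking” paper is the reference. The sketch below only shows the general idea, using the standard NEF-style mapping for a discrete lowpass synapse. It assumes the synapse updates as `x[t+1] = a*x[t] + (1 - a)*v[t]` with `a = exp(-dt/tau)`, where `v` is the synapse's input and `x` its output. Solving for the recurrent matrices that make this reproduce `x[t+1] = Ad @ x[t] + Bd @ u[t]` gives `A_syn = (Ad - a*I) / (1 - a)` and `B_syn = Bd / (1 - a)`. `map_to_lowpass` is a hypothetical helper, and `Ad`/`Bd` are random stand-ins for the discretized LMU matrices. Simulator-specific details, such as how Nengo handles one-timestep delays, are ignored.

```python
import numpy as np

def map_to_lowpass(Ad, Bd, tau, dt=0.001):
    """Hypothetical helper: adjust discrete (Ad, Bd) for a lowpass synapse.

    Assumes the synapse update x[t+1] = a*x[t] + (1 - a)*v[t] with
    a = exp(-dt/tau), where v[t] = A_syn @ x[t] + B_syn @ u[t].
    tau and dt must be in the same units.
    """
    a = np.exp(-dt / tau)
    A_syn = (Ad - a * np.eye(Ad.shape[0])) / (1 - a)
    B_syn = Bd / (1 - a)
    return A_syn, B_syn

# random stand-ins for the discretized LMU matrices (hypothetical values)
rng = np.random.default_rng(0)
order = 4
Ad = 0.9 * np.eye(order) + 0.01 * rng.normal(size=(order, order))
Bd = rng.normal(size=(order, 1))

tau, dt = 0.005, 0.001
A_syn, B_syn = map_to_lowpass(Ad, Bd, tau, dt)
a = np.exp(-dt / tau)

# simulate the original system and the synapse-based version on the same input
u = rng.normal(size=(50, 1))
x_ref = np.zeros(order)
x_syn = np.zeros(order)
max_err = 0.0
for u_t in u:
    x_ref = Ad @ x_ref + Bd @ u_t
    x_syn = a * x_syn + (1 - a) * (A_syn @ x_syn + B_syn @ u_t)
    max_err = max(max_err, np.abs(x_ref - x_syn).max())
print(f"max state difference: {max_err:.2e}")
```

Under the stated assumption, the two simulations agree up to floating-point rounding. This is what the mapping is meant to guarantee.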