Evaluating spiking autoencoder

To get started working with spiking autoencoders, I’ve been trying to follow the spiking MNIST example, as well as other Nengo and NengoDL examples, by building a simple autoencoder that reconstructs MNIST digits. I’ve set up my network and I believe it is being trained properly, since the training loss values look right. It works with regular (rate) neurons, but I’m having trouble with evaluation when I switch to spiking neurons.

I’m confused about how evaluation works for a spiking autoencoder. I changed my test data so that each image repeats for a number of time steps. I thought it would be as simple as calling sim.evaluate(test_data, test_data), since my target is the same as my input, but this doesn’t seem to work. I then tried to display a sample reconstruction the same way I did for the non-spiking autoencoder, i.e.

output = sim.predict(test_data[:minibatch_size])
plt.imshow(output[p_c][29].reshape((28, 28)))

When I do that, I get the following error: ValueError: cannot reshape array of size 23520 into shape (28,28). So something is definitely wrong. I notice that I’m using 30 time steps and that 23520 = 784 * 30, so I must not be selecting the right data to reshape and display.

I understand that I want to evaluate at the final time step of the simulation, since it’s a spiking network, but I’m not sure how to do that or how to display a sample reconstruction. I’m having trouble following the explanation in the spiking MNIST example and would appreciate some clarity on how exactly to set up evaluation with spiking neurons, given the added temporal dimension. Thank you!


Hello Khanus,

It sounds like you’re close. If you post your code, I can take a look and give you more specific feedback on what’s going wrong.

I think you have the right idea about the general approach to validation: feed an input signal into the spiking network, let it run for some number of time steps, then look at the output at the last time step and compare it to the target output. I suspect there may be an issue with the way test_data was set up, but I’ll be able to say for sure once I see the code!
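To make "look at the output at the last time step" concrete, here is a small NumPy sketch (not NengoDL code) that compares an error averaged over every time step with one computed only at the final step. The function names `all_steps_mse` and `last_step_mse` are hypothetical, and the ramping `pred` array is made-up data standing in for a filtered spiking output that takes a few steps to settle.

```python
import numpy as np


def all_steps_mse(y_pred, y_true):
    """Mean squared error averaged over every time step; arrays are (batch, n_steps, dims)."""
    return np.mean((y_pred - y_true) ** 2)


def last_step_mse(y_pred, y_true):
    """Mean squared error using only the final time step of each sequence."""
    return np.mean((y_pred[:, -1] - y_true[:, -1]) ** 2)


batch, n_steps, dims = 4, 30, 784
rng = np.random.default_rng(0)

# target: each image held constant across all time steps, like tiled test data
target = np.tile(rng.random((batch, 1, dims)), (1, n_steps, 1))

# hypothetical prediction that ramps from 0 up to the target over the run
ramp = np.linspace(0.0, 1.0, n_steps)[None, :, None]
pred = target * ramp

print("all steps:", all_steps_mse(pred, target))
print("last step:", last_step_mse(pred, target))
```

Because the early, still-settling steps are included, the all-steps error is nonzero even though the last step matches the target exactly; an error averaged over every time step can therefore look pessimistic for a spiking network.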

Here is how I first load the data. I’m only using images here as the data, ignoring the labels.

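import matplotlib.pyplot as plt
import nengo
import nengo_dl
import numpy as np
import tensorflow as tf

# download the MNIST dataset (labels are not needed for an autoencoder)
(train_data, _), (test_data, _) = tf.keras.datasets.mnist.load_data()

# flatten each 28x28 image into a 784-element vector
train_data = train_data.reshape((train_data.shape[0], -1))
test_data = test_data.reshape((test_data.shape[0], -1))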

The network definition:

with nengo.Network(seed = 0) as auto_net:
    auto_net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    auto_net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    auto_net.config[nengo.Connection].synapse = None # this disables synaptic filtering

    n_type = nengo.PoissonSpiking(nengo.RectifiedLinear())
    n_in = 784
    inter_dim = 128

    inp_node = nengo.Node(np.zeros(n_in))

    # first layer
    enc1 = nengo.Ensemble(inter_dim, 1, neuron_type = n_type)
    nengo.Connection(inp_node, enc1.neurons, transform=nengo_dl.dists.Glorot())

    # second layer
    enc2 = nengo.Ensemble(inter_dim, 1, neuron_type = n_type)
    nengo.Connection(enc1.neurons, enc2.neurons, transform=nengo_dl.dists.Glorot())

    # output layer
    outp = nengo.Ensemble(n_in, 1, neuron_type=n_type)
    nengo.Connection(enc2.neurons, outp.neurons, transform=nengo_dl.dists.Glorot())

    # probes
    p_c = nengo.Probe(outp.neurons)

Here is how I set up the data for training and testing:

train_data = train_data[:, None, :]  # add a time axis: (n_images, 1, 784)
n_steps = 30
test_data = np.tile(test_data[:, None, :], (1, n_steps, 1))  # repeat each image: (n_images, 30, 784)
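As a quick sanity check on these shapes, here is a NumPy-only sketch that uses small random arrays as stand-ins for the flattened MNIST images (an assumption, so it runs without downloading anything). NengoDL expects inputs shaped (batch, n_steps, dimensions), so training uses a single step per image while each test image is held constant for `n_steps` steps.

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-ins for the flattened MNIST arrays (uint8 pixels, 784 per image)
train_data = rng.integers(0, 256, size=(100, 784), dtype=np.uint8)
test_data = rng.integers(0, 256, size=(20, 784), dtype=np.uint8)

n_steps = 30
# training: add a time axis of length 1 -> (100, 1, 784)
train_seq = train_data[:, None, :]
# testing: repeat each image along the time axis -> (20, 30, 784)
test_seq = np.tile(test_data[:, None, :], (1, n_steps, 1))

print(train_seq.shape, test_seq.shape)
# every time step of a test sequence holds the same image
print(np.array_equal(test_seq[:, 0], test_seq[:, -1]))
```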

Training and validation (minibatch_size is 50):

with nengo_dl.Simulator(auto_net, minibatch_size=minibatch_size) as sim:
  sim.compile(optimizer = tf.optimizers.RMSprop(1e-3), 
  loss = tf.losses.mse) # mean squared error as loss function

  # run training loop. If using spiking neurons, this converts the model to a rate based approximation and then does training
  sim.fit(train_data, train_data, epochs = 10)

  # evaluate performance on test set
  print("Test error: ", sim.evaluate(test_data, test_data))

  # display example output of a digit reconstruction
  output = sim.predict(test_data[:minibatch_size])
  plt.imshow(output[p_c][29].reshape((28, 28)))

Thanks for your help!

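Ah, if you add a print(output[p_c].shape) line to your code, you’ll see that the array actually has the shape [minibatch_size, timesteps, dimensions]. Indexing with output[p_c][29] selects the 30th example in the minibatch across all 30 time steps, which is why you get 30 * 784 = 23520 values. So to do the reshape for plotting, you need to choose both an example within the minibatch and a time step: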

plt.imshow(output[p_c][0, 29].reshape((28, 28)))

should get you what you want!
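Here is the same indexing on a NumPy array with the shape that `sim.predict` returns for `p_c` in this thread; the random values are a stand-in for real probe data. It shows why `[29]` yields 23520 values and why `[0, -1]` (example 0, last time step, equivalent to `[0, 29]` when there are 30 steps) reshapes cleanly.

```python
import numpy as np

minibatch_size, n_steps, n_pixels = 50, 30, 784

# stand-in for output[p_c]: (minibatch_size, n_steps, n_pixels)
probe_data = np.random.default_rng(0).random((minibatch_size, n_steps, n_pixels))

wrong = probe_data[29]  # example 29, all time steps
print(wrong.shape, wrong.size)

img = probe_data[0, -1].reshape((28, 28))  # example 0, final time step
print(img.shape)

# final-step images for the first five examples, ready for plt.imshow one at a time
first_images = probe_data[:5, -1].reshape((-1, 28, 28))
print(first_images.shape)
```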

Thanks, that fixed the error and plotted something. But I am confused by the result:

This is a pretty simple task, and the training error numbers look right. I first did the task with rectified linear neurons and it worked great. Pretty much all I’ve changed is the neuron type, to a PoissonSpiking wrapper around a rectified linear neuron, plus the corresponding changes to the data to account for the temporal nature of an SNN. My understanding is that, since training is done on a rate-based version of the network rather than on the spiking network directly, the network should perform comparably in both cases. I must still be doing something wrong, but I’m not sure what. Since the test error is really high while training seems fine, my guess is that there’s a problem with how I’m using sim.evaluate and sim.predict:

print("Test error: ", sim.evaluate(test_data, test_data))
output = sim.predict(test_data[:minibatch_size])

Do I have to make a change here too? I had thought that since I added the temporal dimension to the test data earlier I wouldn’t have to change anything here. Thanks so much!

Hm, can you send me your current script in full (i.e., attach the file, with all imports etc. included)?
I’m having trouble reproducing your results on my side and want to make sure I’m debugging from the same place.

nengo_autoencoder.ipynb (26.3 KB)

Here’s the Jupyter notebook. I thought about it a little more, and I’m wondering if the issue is with some of the initial values I’m setting, or with the fact that I’m not filtering the output. Thanks for your help!

Hi Khanus,

You were very close! There were two issues, both involving synapses. Please find attached a notebook with the corrections.

The first issue was the probe on the output. The default synapse for Probes is None, which means that in your example all the neurons would have to fire on exactly the same time step to produce the output you want to see. Because neurons have temporal dynamics (such as refractory periods, during which they can’t spike again for a short time after spiking), they won’t spike on every time step unless their firing rates are set very high. With max_rates of 100 Hz and the default dt of 0.001 s, even the most active neuron spikes on only about one time step in ten. So you need some filtering to average the output over several time steps. To do this, we create another probe with a synapse (filter) on it, and we read from that probe when we want to plot the spiking output.
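To see why an unfiltered probe looks so poor, here is a NumPy-only sketch of a single neuron firing at 100 Hz (matching `max_rates` above) with `dt = 0.001`, before and after a low-pass filter with a 0.01 s time constant. Two assumptions: spikes are drawn as Poisson counts per step and scaled by 1/dt (my understanding of how `nengo.PoissonSpiking` and Nengo spike outputs behave), and the filter is a simple exponential discretization of a first-order low-pass, which may differ in small details (such as a one-step delay) from Nengo's own `Lowpass` implementation.

```python
import numpy as np

dt = 0.001        # simulator time step (s)
rate = 100.0      # firing rate (Hz), matching max_rates=100
tau = 0.01        # probe synapse time constant (s)
n_steps = 20000
warmup = 500      # skip the start, while the filter state rises from 0

rng = np.random.default_rng(0)
# Poisson spike counts per step, scaled by 1/dt so each spike has area 1
spikes = rng.poisson(rate * dt, size=n_steps) / dt

# first-order low-pass filter, exponential discretization
a = np.exp(-dt / tau)
filtered = np.zeros(n_steps)
state = 0.0
for t in range(n_steps):
    state = a * state + (1 - a) * spikes[t]
    filtered[t] = state

print("fraction of steps with a spike:", np.mean(spikes > 0))
print("unfiltered: mean %.1f, std %.1f" % (spikes.mean(), spikes.std()))
print("filtered:   mean %.1f, std %.1f" % (filtered[warmup:].mean(), filtered[warmup:].std()))
```

On most steps the unfiltered value is 0, and on the others it jumps to 1000 or more, so any single step is a poor estimate of the 100 Hz rate. The filtered trace has roughly the same mean but a much smaller spread.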

The other thing we can do is set the synapses on the rest of the connections in the network to something other than None. Their outputs are then also averaged over time, so they better approximate the rate neuron outputs used during training, giving a closer match between the rate and spiking implementations. To change the synapses after training, I use the sim.freeze_params(auto_net) call, which stores the trained weights back in the network. I can then modify the connections and create another Simulator with the version of the network that has both the trained weights and the synapses.
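The effect of synapses on internal connections can be sketched in NumPy as well. Below, a hypothetical two-layer network with random weights (not the trained autoencoder) has Poisson-spiking hidden units feeding a rate-based ReLU output layer. Its output is compared with the purely rate-based version of the same network, once with no synapse on the hidden-to-output connection and once with a 0.005 s low-pass. Assumptions: spikes are Poisson counts per step scaled by 1/dt, the low-pass is a simple exponential discretization rather than Nengo's exact `Lowpass`, and the sizes, weight scales, and seed are arbitrary.

```python
import numpy as np

dt = 0.001
n_steps = 3000
warmup = 200  # skip the start, while the filter state is still rising from 0
rng = np.random.default_rng(1)


def relu(v):
    return np.maximum(v, 0.0)


def lowpass(signal, tau):
    """First-order low-pass along axis 0; tau=None passes the signal through unchanged."""
    if tau is None:
        return signal
    a = np.exp(-dt / tau)
    out = np.empty_like(signal)
    state = np.zeros(signal.shape[1])
    for t in range(signal.shape[0]):
        state = a * state + (1 - a) * signal[t]
        out[t] = state
    return out


# hypothetical network: 20 inputs -> 50 hidden units -> 10 outputs
x = rng.random(20)
W1 = rng.normal(0.0, 20.0, (50, 20))
W2 = rng.normal(0.0, 0.1, (10, 50))

# rate-based reference (what training optimizes)
hidden_rates = relu(W1 @ x)  # firing rates in Hz
y_rate = relu(W2 @ hidden_rates)

# spiking hidden layer: Poisson spike counts per step, scaled by 1/dt
spikes = rng.poisson(hidden_rates * dt, size=(n_steps, hidden_rates.size)) / dt


def output_rmse(tau):
    """RMS difference between the spiking-network output and the rate output."""
    y_spiking = relu(lowpass(spikes, tau) @ W2.T)  # (n_steps, 10)
    return np.sqrt(np.mean((y_spiking[warmup:] - y_rate) ** 2))


print("synapse=None: ", output_rmse(None))
print("synapse=0.005:", output_rmse(0.005))
```

With these made-up numbers the filtered version tracks the rate output much more closely; the exact values depend on the seed and weight scales.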

Does that make sense?

nengo_autoencoder.ipynb (17.9 KB)

Hi Travis,

Awesome, thanks so much for your help! I think I understand; I felt like I was missing something about the temporal dynamics, and that makes a lot of sense. With regard to the filtered probe, what exactly is being filtered out? Or is its function just to average values across time so we capture all the spikes we need? I think I assumed that setting it to None would be like having a filter with gain 1, but in hindsight that doesn’t make much sense :sweat_smile:

In the same vein, I have a couple of questions about the filtering. How did you choose the values, i.e., 0.01 for the probe and 0.005 for the other synapses in the network? Is there anywhere I can read more about the role synaptic filtering plays here?

Thanks for going back and forth with me on this so much!

No problem!

Ah, also, sorry: I’m using the term filter in the signal processing sense, where the high-frequency components of the signal are what get filtered out. Are you familiar with low-pass filters? The default synapse in Nengo is a low-pass filter. You can think of it this way: the larger the time constant of the low-pass filter, the longer the window the signal is averaged over, and the smoother the signal gets.

0.005 s is the default Nengo synapse time constant, so that’s how I chose it for the synapses on the network connections. The 0.01 on the output was (probably) chosen by just trying out a few values and seeing which gave the best results.
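The trade-off behind the time constant can be put in numbers. For a first-order low-pass discretized as y[t] = a*y[t-1] + (1-a)*x[t] with a = exp(-dt/tau) (an assumed discretization, not necessarily Nengo's exact one), this sketch computes how much of a constant input the filter has reached after the 30 simulation steps used in this thread, and how much it shrinks the standard deviation of independent per-step noise. The candidate `tau` values are only examples.

```python
import numpy as np


def filter_tradeoff(tau, dt=0.001, n_steps=30):
    """Return (fraction of a constant input reached after n_steps, noise std gain)."""
    a = np.exp(-dt / tau)
    settled = 1 - a ** n_steps               # step response starting from a zero state
    noise_gain = np.sqrt((1 - a) / (1 + a))  # steady-state std ratio for white input
    return settled, noise_gain


for tau in (0.005, 0.01, 0.05):
    settled, noise_gain = filter_tradeoff(tau)
    print(f"tau={tau}: reaches {settled:.3f} of the input, noise std x{noise_gain:.3f}")
```

Longer time constants smooth more but may not have settled by the last step of a 30-step run (with tau=0.05 the output is still under half of the input, while tau=0.01 reaches about 95%), which is one reason to try a few values and compare.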