Surrogate gradient training under NengoDL

Hello, I have a question regarding fine-tuning SNNs under NengoDL.

By fine-tuning, I mean updating the weights of the SNN after the ANN has been trained and converted. For example, we could train the SNN for a few epochs using the real spiking function (IF/LIF) in the forward pass, switch to the surrogate, or “soft”, function (as Nengo calls it) in the backward pass, and then update the weights accordingly.

This surrogate gradient/STDB approach has been done before, but I am curious whether you think it would be possible under NengoDL. Please let me know what you think of this idea; I look forward to hearing your input, thanks.

Hi @mbaltes,

I forwarded your question to the NengoDL devs, and they inform me that what you are looking to do is technically possible. As a bit of background, when you train an SNN in NengoDL, it uses the rate activation functions for both the forward and backward passes. To modify this behaviour, you’ll need to create custom neuron and builder classes (or monkey-patch the existing neuron builder class) that return the appropriate tensor in the training_step function of the associated neuron builder.

Inside the training_step function, you’ll want something like this:

forward = self.step(...)  # Uses the spiking step function for the forward pass
backward = LIFRateBuilder.step(...)  # Uses the rate step function for the backwards pass
return backward + tf.stop_gradient(forward - backward) 

I’m not too familiar with the tf.stop_gradient function, so here’s a quote from one of the devs that may provide some insight into this method:

def pseudo_gradient(forward, backward):
    """Selects between the forward or backward tensor."""
    # following trick is courtesy of drasmuss/hunse:
    # on the forwards pass, the backward terms cancel out;
    # on the backwards pass, only the backward gradient is let through.
    # this is similar to using tf.custom_gradient, but that is
    # error-prone since certain parts of the graph might not
    # exist in the context that any tf.gradients are evaluated, which
    # can lead to None gradients silently getting through
    # note: may be possible to optimize this a bit
    return backward + tf.stop_gradient(forward - backward)
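
To see what the trick does in isolation, here is a minimal sketch (the stand-ins are mine, not from the devs: tf.round plays the role of the non-differentiable spiking function, and the identity is the rate surrogate):

import tensorflow as tf

def pseudo_gradient(forward, backward):
    # the value of this expression equals forward;
    # the gradient flows only through backward
    return backward + tf.stop_gradient(forward - backward)

x = tf.Variable(1.7)
with tf.GradientTape() as tape:
    forward = tf.round(x)  # non-differentiable stand-in for the spiking step
    backward = x           # smooth surrogate (identity here)
    y = pseudo_gradient(forward, backward)

print(y.numpy())                    # 2.0 -- the forward (spiking) value
print(tape.gradient(y, x).numpy())  # 1.0 -- the surrogate's gradient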

As a side note, the KerasSpiking python package allows you to do spiking-aware training within Keras (TF) itself (see this example). Thus, another approach would be to train your network within Keras (using KerasSpiking) and then use the NengoDL converter on that. The caveat here is that I haven’t tested this method, so there may be some debugging needed to figure out the appropriate approach.
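
The conversion step itself would look something like this (a sketch; as mentioned, I haven’t tested whether the converter handles a KerasSpiking-trained model cleanly):

import nengo_dl

# "model" is the Keras model trained with KerasSpiking layers (untested assumption)
converter = nengo_dl.Converter(model)
with nengo_dl.Simulator(converter.net) as sim:
    ...  # run/evaluate the converted network here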

Hello, thank you for your reply and for reaching out to the developers.

I think I will first try looking into the KerasSpiking package and see whether I can use the converter on that before I try modifying the source code. Like you said, there may be some debugging needed; if so, I will be sure to let you know how it goes.


@xchoo I have been going through the KerasSpiking examples and tried applying the approach to a small model I have, but I have not had any luck, and I was wondering if you could take a look and tell me what the problem is.

I have a simple convolutional denoising autoencoder that uses the MNIST dataset. This is the network I have been using:

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model

input = layers.Input(shape=(28, 28, 1))

# Encoder
x = layers.Conv2D(32, (3, 3), activation=tf.nn.relu, padding="same")(input)
x = layers.MaxPooling2D((2, 2), padding="same")(x)
x = layers.Conv2D(32, (3, 3), activation=tf.nn.relu, padding="same")(x)
x = layers.MaxPooling2D((2, 2), padding="same")(x)

# Decoder
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation=tf.nn.relu, padding="same")(x)
x = layers.Conv2DTranspose(32, (3, 3), strides=2, activation=tf.nn.relu, padding="same")(x)
out = layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")(x)

model = Model(inputs=input, outputs=out)
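
I compile and train it along these lines (a sketch; the array names and epoch/batch values below are placeholders):

model.compile(optimizer="adam", loss="mse")
# x_train_noisy holds the corrupted inputs, x_train the clean targets (placeholder names)
model.fit(x_train_noisy, x_train, epochs=10, batch_size=128)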

For the KerasSpiking model I changed the max-pooling layers to Conv2D layers with stride 2 and restructured it to follow the example in the documentation. This is what I came up with so far:

import tensorflow as tf
from tensorflow.keras import layers
import keras_spiking

keras_spiking.default.dt = 0.01
LOWPASS = 0.1

filtered_model = tf.keras.Sequential(
    [
        # input 
        layers.Reshape((-1, 28, 28, 1), input_shape=(None, 28, 28)),
        # conv1
        layers.TimeDistributed(layers.Conv2D(32, (3, 3), padding='same', name='conv1')),
        keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
        keras_spiking.Lowpass(LOWPASS, return_sequences=False),
        # max pool
        layers.TimeDistributed(layers.Conv2D(32, 2, 2, padding='valid', name='pool1')),
        keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
        keras_spiking.Lowpass(LOWPASS, return_sequences=False),
        # conv2
        layers.TimeDistributed(layers.Conv2D(32, (3, 3), padding='same', name='conv2')),
        keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
        keras_spiking.Lowpass(LOWPASS, return_sequences=False),
        # max pool
        layers.TimeDistributed(layers.MaxPooling2D((2, 2), padding='same', name='pool2')),

        # deconv1
        layers.TimeDistributed(layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same', name='deconv1')),
        keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
        keras_spiking.Lowpass(LOWPASS, return_sequences=False),
        # deconv2
        layers.TimeDistributed(layers.Conv2DTranspose(32, (3, 3), strides=2, padding='same', name='deconv2')),
        keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
        keras_spiking.Lowpass(LOWPASS, return_sequences=False),

        # output
        layers.Conv2D(1, (3, 3), activation="sigmoid", padding="same")
    ]
)

However, this results in a dimension error in the first pooling layer:

ValueError: Input 0 of layer pool1 is incompatible with the layer: : expected min_ndim=4, found ndim=3. Full shape received: (None, 28, 32)

I tried flattening the input like in the KerasSpiking example, but this led to a dimension error in the first convolutional layer.

I’ve been looking for examples that use TimeDistributed with a Conv2D layer (like in this example: TimeDistributed layer), but I tried doing the same thing and haven’t had any luck successfully creating the model for training.

If you have the time, I would greatly appreciate it if you could look into why this model is not working with the KerasSpiking package. If you can’t find anything, I can move on to modifying the source code like you said. I look forward to hearing back from you, thanks.

Hi @mbaltes,

Unfortunately, I don’t have enough experience with KerasSpiking to immediately identify what is causing the shape mismatch issue. I’ll forward your questions to the KerasSpiking devs and will update when I have a response! :grinning:

If I were to guess, though, it might be that this is a typo and the cause of the issue:

# should this be (2, 2) instead of 2, 2?
layers.TimeDistributed(layers.Conv2D(32, 2, 2, padding='valid', name='pool1'))

Or that the signals need flattening or unflattening somewhere.
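
In any case, writing the arguments with explicit keywords would remove the ambiguity (assuming a 2x2 kernel with stride 2 was the intent):

layers.TimeDistributed(layers.Conv2D(32, kernel_size=(2, 2), strides=2, padding='valid', name='pool1'))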

The devs got back to me and informed me that the error is caused by a misconfigured Lowpass layer. For the network to be spiking, it needs to be time aware (i.e., there needs to be a time dimension). This is why the Conv2D layers are wrapped in the TimeDistributed layer. By default, the Lowpass layer is also time aware, and you can connect two time-aware layers without any changes to the code.

But in your code, you have configured the Lowpass layers with return_sequences=False. This makes a Lowpass layer return only the last value in the sequence (i.e., it collapses the time dimension down to the last value). While this makes sense for the output layer of the network, if you do this and then try to connect it to a time-aware layer, you’ll encounter the error you saw. Essentially, you are trying to connect a non-time-aware layer to a time-aware layer.

Fixing this error is straightforward: simply set return_sequences=True for the intermediate Lowpass layers (or leave out the argument entirely; return_sequences=True is the default value).
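
Applied to your model, the first block would then look like this:

# conv1
layers.TimeDistributed(layers.Conv2D(32, (3, 3), padding='same', name='conv1')),
keras_spiking.SpikingActivation("relu", spiking_aware_training=True),
# return_sequences=True (the default) keeps the time dimension, so the next
# TimeDistributed layer still receives a (batch, time, 28, 28, 32) tensor
keras_spiking.Lowpass(LOWPASS, return_sequences=True),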

You can read more about the return_sequences parameter here and here.

Note: I still think the 2, 2 vs (2, 2) is a typo.

@xchoo thank you for your response. I must have missed that when reading through the examples. After making the small changes you mentioned, I was able to get the SNN working. However, my results are not that great.

The ANN’s predicted images resulted in an MSE of about 0.007.
The SNN converted using NengoDL’s converter resulted in an MSE of about 0.013.
The SNN trained using spiking-aware training resulted in an MSE of about 49.083.

I found this surprising because I tried to configure the hyperparameters of the spiking-aware SNN to match the NengoDL model as closely as possible.

In the NengoDL SNN I had these hyperparameters:

n_steps: 30
synapse: 0.01
activation: nengo.SpikingRectifiedLinear()
scale_firing_rates: 500
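
For context, these map onto the NengoDL converter call roughly like this (a sketch; n_steps is the number of simulation timesteps used when running the converted network):

import nengo
import nengo_dl
import tensorflow as tf

converter = nengo_dl.Converter(
    model,
    swap_activations={tf.nn.relu: nengo.SpikingRectifiedLinear()},
    scale_firing_rates=500,
    synapse=0.01,
)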

In the KerasSpiking SNN I tried to mimic them as closely as possible:

n_steps: 30
lowpass_tau: 0.01  # I'm assuming this works the same way as the synapse parameter, so I kept it the same
activation: keras_spiking.SpikingActivation("relu", spiking_aware_training=True)
# As for scaling the firing rates, I tried with and without L2 percentile regularization
# (activity_regularizer=keras_spiking.regularizers.Percentile(l2=0.01, target=(100, 500)))
# targeting 100-500 Hz, but I did not get the best results

I know that you mentioned not being too familiar with KerasSpiking, but if you, or the devs, could identify what is causing this model to perform poorly, it would be greatly appreciated.

So I have been looking into my network as well as other KerasSpiking functions and got some OK results.

I tried using the DtScheduler to decay the dt value over the course of training, but the results did not come out all that great.
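
For reference, this is roughly how I wired it up (a sketch; the schedule values are placeholders, and the exact API is worth double-checking against the KerasSpiking docs):

import tensorflow as tf
import keras_spiking

# dt as a shared variable so the callback can update it during training,
# passed to each time-aware layer, e.g.
# keras_spiking.SpikingActivation("relu", dt=dt, spiking_aware_training=True)
dt = tf.Variable(1.0)

dt_callback = keras_spiking.callbacks.DtScheduler(
    dt, tf.optimizers.schedules.ExponentialDecay(1.0, decay_steps=1000, decay_rate=0.5)
)
# filtered_model.fit(..., callbacks=[dt_callback])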

I then set the dt to just 1.0 throughout training and got some better results:
MSE: 0.0407 (previous was 49.083)

Obviously, we don’t want to use a high dt during training, because we lose the temporal sparsity advantage that SNNs offer. I will be looking into other solutions to see if I can bring the dt value down while still achieving good results.
