Surrogate gradient training under NengoDL

Hello, I have a question regarding fine-tuning SNNs under NengoDL.

By fine-tuning, I mean updating the weights of the SNN after the ANN has been trained and converted. For example, we could train the SNN for a few epochs using the real spiking function (IF/LIF) in the forward pass, switch to the surrogate, or "soft", function (as Nengo calls it) in the backward pass, and then update the weights accordingly.

This surrogate gradient/STDB approach has been done before, but I am curious whether you think this would be possible under NengoDL. Please let me know what you think of this idea; I look forward to hearing your input. Thanks.

Hi @mbaltes,

I forwarded your question to the NengoDL devs, and they tell me that what you are looking to do is technically possible. As a bit of background: when you train an SNN in NengoDL, it uses the rate activation functions for both the forward and backward passes. To modify this behaviour, you'll need to create custom neuron and builder classes (or monkey-patch the existing neuron builder class) so that the training_step function of the associated neuron builder returns the appropriate tensor.

Inside the step function, you’ll want something like this:

forward = self.step(...)  # Uses the spiking step function for the forward pass
backward = LIFRateBuilder.step(...)  # Uses the rate step function for the backwards pass
return backward + tf.stop_gradient(forward - backward) 

I'm not too familiar with the tf.stop_gradient function myself, so here's a snippet from one of the devs that may provide some insight into this method:

def pseudo_gradient(forward, backward):
    """Selects between the forward or backward tensor."""
    # following trick is courtesy of drasmuss/hunse:
    # on the forwards pass, the backward terms cancel out;
    # on the backwards pass, only the backward gradient is let through.
    # this is similar to using tf.custom_gradient, but that is
    # error-prone since certain parts of the graph might not
    # exist in the context that any tf.gradients are evaluated, which
    # can lead to None gradients silently getting through
    # note: may be possible to optimize this a bit
    return backward + tf.stop_gradient(forward - backward)
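To make the trick concrete, here is a minimal, self-contained demonstration assuming TensorFlow 2.x. The `forward` tensor stands in for a non-differentiable spiking output and `backward` for the smooth rate approximation; these particular functions are just illustrative choices, not anything from NengoDL itself:

```python
import tensorflow as tf

def pseudo_gradient(forward, backward):
    """Forward pass returns `forward`; gradients flow through `backward`."""
    return backward + tf.stop_gradient(forward - backward)

x = tf.Variable(2.0)
with tf.GradientTape() as tape:
    forward = tf.round(x)   # stand-in for a non-differentiable spiking step
    backward = x ** 2       # stand-in for a smooth rate function
    out = pseudo_gradient(forward, backward)
grad = tape.gradient(out, x)

print(out.numpy())   # 2.0 -> value of the forward (spiking) branch
print(grad.numpy())  # 4.0 -> gradient of the backward (rate) branch
```

On the forward pass the `backward` terms cancel, leaving the spiking value; on the backward pass, `tf.stop_gradient` blocks the second term, so only the rate gradient propagates, exactly as the comments in the quoted function describe.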

As a side note, the KerasSpiking python package allows you to do spike-aware training within Keras (TF) itself (see this example). Thus, another approach would be to train your network within Keras (using KerasSpiking) and then run the NengoDL converter on that. The caveat here is that I haven't tested this method, so there may be some debugging needed to figure out the appropriate approach.