Help with Spiking Generative Language Models

Hi everyone, I am currently working on a research project involving spiking neural networks and generative language models. I am collaborating with my teacher to explore the potential of SNNs for natural language generation.

I have been working on adapting my existing TensorFlow model to use NengoDL for spiking neural networks. I would appreciate any guidance or resources that could help me in this transition.

Specifically, I am looking for:

  1. Best practices for converting TensorFlow models to NengoDL.
  2. Examples of spiking neural network implementations for natural language processing.
  3. Any tips on optimizing SNNs for text generation tasks.
  4. How can I make my code run faster? I am running this on a PC and it takes too long.

Thank you!

  1. Here are some helpful guides:
    Coming from TensorFlow to NengoDL — NengoDL 3.6.1.dev0 docs
    Converting a Keras model to a spiking neural network — NengoDL 3.6.1.dev0 docs
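The Keras guide converts a model with `nengo_dl.Converter`. One of the converter's settings is `scale_firing_rates`, which multiplies each neuron's input by a factor and divides its output spikes by the same factor. For a ReLU the average output stays the same, but more spikes arrive in each time window, so short readouts are less noisy. Below is a minimal NumPy sketch of that effect. It assumes a simplified non-leaky integrate-and-fire neuron, and `spiking_relu` is a hypothetical helper written for illustration, not NengoDL's own code.

```python
import numpy as np


def spiking_relu(x, scale=1.0, dt=0.001, n_steps=50):
    """Toy non-leaky integrate-and-fire emulation of a spiking ReLU.

    Hypothetical helper: the input is multiplied by `scale` and each spike's
    amplitude is divided by `scale`, so the time-averaged output stays close
    to max(x, 0) whatever the scale.
    """
    x = np.asarray(x, dtype=float)
    voltage = np.zeros_like(x)
    outputs = np.zeros((n_steps,) + x.shape)
    for step in range(n_steps):
        voltage += np.maximum(scale * x, 0.0) * dt  # rectified input current
        n_spikes = np.floor(voltage)                 # may exceed 1 at high rates
        voltage -= n_spikes                          # reset by subtraction
        outputs[step] = n_spikes / (dt * scale)      # spike height 1/dt, scaled down
    return outputs


x = np.array([-1.0, 0.0, 5.0, 20.0, 37.0])
for scale in (1, 10, 100):
    mean_out = spiking_relu(x, scale=scale).mean(axis=0)
    err = np.abs(mean_out - np.maximum(x, 0.0)).max()
    print(f"scale={scale:>3}: max error of 50-step average = {err:.3f}")
```

In this toy model the error of the averaged output is bounded by 1/(scale * window length). Raising the firing-rate scale or presenting each input for more timesteps both reduce it, at the cost of higher firing rates or longer simulations.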

  2. One example I can think of is
    CNRGlab @ UWaterloo | Publications
    Maybe this would be of interest to you as well: CNRGlab @ UWaterloo | Publications. It doesn’t use spiking neural networks, but it is based on the Legendre Memory Unit (LMU), which has been implemented with spiking neurons in Nengo before. Here’s how you can create a spiking LMU network in Nengo:

 import numpy as np
 import matplotlib.pyplot as plt
 import nengo
 from scipy.special import legendre
 # n_neurons: num of neurons per memory vector
 # theta: length of time window (in seconds)
 # q: LMU memory vector size
 # size_in: dim of signal to remember
 class LMUNetwork(nengo.Network):
     def __init__(self, n_neurons, theta, q, size_in=1, tau=0.05,**kwargs):
         super().__init__()
         self.q = q              # number of internal state dimensions per input
         self.theta = theta      # size of time window (in seconds)
         self.size_in = size_in  # number of inputs
         # Do Aaron's math to generate the matrices
         #  https://github.com/arvoelke/nengolib/blob/master/nengolib/synapses/analog.py#L536
         Q = np.arange(q, dtype=np.float64)
         R = (2*Q + 1)[:, None] / theta
         j, i = np.meshgrid(Q, Q)
         self.A = np.where(i < j, -1, (-1.)**(i-j+1)) * R
         self.B = (-1.)**Q[:, None] * R
         with self:
             self.input = nengo.Node(size_in=size_in)
             self.reset = nengo.Node(size_in=1)
             self.lmu = nengo.networks.EnsembleArray(n_neurons, n_ensembles=size_in,
                                                     ens_dimensions=q, **kwargs)
             self.output = self.lmu.output           
             for i in range(size_in):
                 nengo.Connection(self.input[i], self.lmu.ea_ensembles[i], synapse=tau,
                                  transform = tau*self.B)
                 nengo.Connection(self.lmu.ea_ensembles[i], self.lmu.ea_ensembles[i], synapse=tau,
                                  transform = tau*self.A + np.eye(q))
                 nengo.Connection(self.reset, self.lmu.ea_ensembles[i].neurons, transform = [[-2.5]]*n_neurons, synapse=None)
 # Example use: recall a band-limited white-noise input theta seconds later
 theta = 0.2
 q=10
 n_neurons=800
 tau=0.03
 prb_syn=0.01
 model = nengo.Network()
 with model:
     inp = nengo.Node(nengo.processes.WhiteSignal(2, high=5, rms=0.3, seed=1))
     ldn = LMUNetwork(n_neurons, theta=theta, q=q, size_in=1, tau=tau)
     nengo.Connection(inp, ldn.input, synapse=None)
     in_p = nengo.Probe(inp, synapse=None)
     lmu_p = nengo.Probe(ldn.output, synapse=prb_syn)
     recall = nengo.Node(size_in=1)
     r = 1.0  # fraction of theta to recall: r=1 reads out u(t - theta), r=0 reads out u(t)
     delay_matrix = np.kron(np.eye(1), np.asarray([legendre(i)(2*r - 1) for i in range(q)]).reshape(q, -1).T)
     nengo.Connection(ldn.output, recall, transform = delay_matrix, synapse=prb_syn)
     recall_p = nengo.Probe(recall, synapse=None)
 with nengo.Simulator(model,dt=0.001) as sim:
     sim.run(1.5)
 ts = sim.trange()
 plt.figure(figsize=(9,3))
 plt.subplot(1,2,1)
 plt.title("Recall from SNN LDN")
 plt.plot(ts, sim.data[in_p])
 plt.plot(ts, sim.data[recall_p])
 plt.legend(['signal','delayed'])
 plt.subplot(1,2,2)
 plt.plot(ts, sim.data[in_p])
 plt.plot(ts - theta , sim.data[recall_p])
 plt.legend(['signal','delayed, shifted'])
 plt.show()
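Because the spiking network is slow to iterate on, a non-spiking NumPy reference of the same linear system can be handy for checking the matrices and the Legendre readout first. The sketch below reuses the `A` and `B` formulas from `LMUNetwork`, discretizes the continuous system dm/dt = A m + B u with a zero-order hold, and decodes a delayed copy of a 1 Hz sine. The helper names, the test signal, and the error tolerance are assumptions for illustration, and spiking noise is not modelled.

```python
import numpy as np
from scipy.linalg import expm
from scipy.special import legendre


def ldn_matrices(q, theta):
    """Continuous-time LDN matrices (same formulas as in LMUNetwork above)."""
    Q = np.arange(q, dtype=np.float64)
    R = (2 * Q + 1)[:, None] / theta
    j, i = np.meshgrid(Q, Q)
    A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R
    B = (-1.0) ** Q[:, None] * R
    return A, B


def zoh_discretize(A, B, dt):
    """Zero-order-hold discretization via the exponential of an augmented matrix."""
    q = A.shape[0]
    M = np.zeros((q + 1, q + 1))
    M[:q, :q] = A
    M[:q, q:] = B
    Md = expm(M * dt)
    return Md[:q, :q], Md[:q, q:]


def legendre_readout(q, r):
    """Decoder for u(t - r * theta): r = 0 gives no delay, r = 1 the full window."""
    return np.array([legendre(i)(2 * r - 1) for i in range(q)])


theta, q, dt, freq = 0.2, 10, 0.001, 1.0
A, B = ldn_matrices(q, theta)
Ad, Bd = zoh_discretize(A, B, dt)

n_steps = 4000
t_in = np.arange(n_steps) * dt           # u[k] is held over [t_in[k], t_in[k] + dt)
u = np.sin(2 * np.pi * freq * t_in)
t_state = t_in + dt                      # states[k] is the memory at the end of step k

m = np.zeros((q, 1))
states = np.zeros((n_steps, q))
for k in range(n_steps):
    m = Ad @ m + Bd * u[k]
    states[k] = m[:, 0]

recall = states @ legendre_readout(q, 1.0)   # estimate of u(t - theta)
target = np.sin(2 * np.pi * freq * (t_state - theta))
settled = t_state > 2.0                  # ignore the start-up transient
print("max recall error:", np.abs(recall[settled] - target[settled]).max())
```

Changing `r` in `legendre_readout(q, r)` selects any delay between 0 and `theta`, which corresponds to the `legendre(i)(2*r - 1)` term used to build `delay_matrix` above.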
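  4. Both nengo-dl and nengo-ocl can run simulations on a GPU (nengo-dl through TensorFlow, nengo-ocl through OpenCL), so using one of them instead of the default nengo.Simulator should speed up simulation run time.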
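Switching backends mostly means changing which `Simulator` class you construct, since nengo, nengo-dl and nengo-ocl each provide one. Here is a minimal sketch, assuming you want to fall back to the reference CPU simulator when neither GPU backend is installed. `pick_simulator` is a hypothetical helper, and backend-specific options (for example NengoDL's `device` and `minibatch_size` arguments) are left out.

```python
import importlib

# GPU-capable backends first, the reference CPU simulator last.
DEFAULT_BACKENDS = ("nengo_dl", "nengo_ocl", "nengo")


def pick_simulator(backends=DEFAULT_BACKENDS, attr="Simulator"):
    """Return (module name, Simulator class) for the first importable backend.

    Hypothetical helper: a backend counts as available if its module imports
    and exposes the requested attribute (the `Simulator` class by default).
    """
    for name in backends:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue  # not installed, or a dependency such as pyopencl is missing
        sim_class = getattr(module, attr, None)
        if sim_class is not None:
            return name, sim_class
    raise ImportError(f"none of {backends} provides {attr!r}")


# Usage with the LMU example above (not run here):
# name, Simulator = pick_simulator()
# with Simulator(model, dt=0.001) as sim:
#     sim.run(1.5)
```

A successful import does not by itself guarantee GPU execution. nengo-ocl, for instance, may still fail when the simulator is created if no working OpenCL driver is available.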