In [None]:
import nengo
import numpy as np
import tensorflow as tf

import nengo_dl

seed = 0  # single global seed shared by TensorFlow and NumPy
tf.random.set_seed(seed)
np.random.seed(seed)

def get_model(include_kr=False, l2_strength=1e-3):
    """Build the small Keras CNN classifier used for MNIST in this notebook.

    Layers: Conv2D(32, 3x3, ReLU) -> Conv2D(64, 3x3, stride 2, ReLU) -> Flatten
    -> Dense(32, ReLU) -> Dense(64, ReLU) -> Dense(10, softmax).

    The input and the conv/dense layers get explicit names equal to the ones
    Keras auto-assigns to the first model built in a fresh session. The
    training batch generator feeds NengoDL inputs under keys such as
    "input_1", "conv2d.0.bias", "conv2d_1.0.bias" and "dense_2.0.bias", which
    are presumably derived from these layer names. Auto-names gain a suffix on
    every later call (conv2d_2, dense_3, ...), so fixed names keep those keys
    valid however many times this function has run.

    Args:
        include_kr: If True, attach an L2 kernel regularizer to both conv
            layers and to the two hidden Dense layers. The output layer never
            gets one.
        l2_strength: L2 penalty factor used when ``include_kr`` is True.

    Returns:
        Tuple ``(model, inp, dense)``: the ``tf.keras.Model``, its input tensor
        and its softmax output tensor.

    Side effect: prints ``model.summary()``.
    """
    def kernel_regularizer():
        # Fresh regularizer per layer. None is Keras's "no regularization" default.
        return tf.keras.regularizers.l2(l2_strength) if include_kr else None

    inp = tf.keras.Input(shape=(28, 28, 1), name="input_1")

    # Convolutional layers.
    # NOTE: conv layers use tf.nn.relu, while the hidden Dense layers below use
    # the "relu" string (tf.keras.activations.relu). These are different
    # function objects, so anything keyed on the activation function (e.g.
    # NengoDL's swap_activations) presumably has to list both.
    conv0 = tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=3,
        activation=tf.nn.relu,
        kernel_regularizer=kernel_regularizer(),
        name="conv2d",
    )(inp)
    conv1 = tf.keras.layers.Conv2D(
        filters=64,
        kernel_size=3,
        strides=2,
        activation=tf.nn.relu,
        kernel_regularizer=kernel_regularizer(),
        name="conv2d_1",
    )(conv0)

    flatten = tf.keras.layers.Flatten()(conv1)

    # Fully connected hidden layers.
    dense = tf.keras.layers.Dense(
        units=32,
        activation="relu",
        kernel_regularizer=kernel_regularizer(),
        name="dense",
    )(flatten)
    dense = tf.keras.layers.Dense(
        units=64,
        activation="relu",
        kernel_regularizer=kernel_regularizer(),
        name="dense_1",
    )(dense)

    # Output layer (never regularized).
    dense = tf.keras.layers.Dense(units=10, activation="softmax", name="dense_2")(dense)

    model = tf.keras.Model(inputs=inp, outputs=dense)
    model.summary()
    return model, inp, dense

In [None]:
# Download MNIST and lay it out the way NengoDL consumes data:
# (batch, n_steps, features), with each 28x28 image flattened to 784 values.
N_TEST_STEPS = 30  # number of simulation steps each test image is presented for

(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

# Training: one time step per image; labels one-hot encoded to shape (N, 1, 10).
train_x = train_x.reshape((train_x.shape[0], 1, -1))
train_y = np.eye(10)[train_y].reshape((train_y.shape[0], 1, -1))

# Testing: repeat each image for N_TEST_STEPS steps so the spiking network has
# time to respond. test_y keeps integer class labels (compared against argmax later).
test_x = np.tile(test_x.reshape((test_x.shape[0], 1, -1)), (1, N_TEST_STEPS, 1))

In [None]:
# Sanity-check the prepared arrays: inputs and targets must have matching sample
# counts (and matching time axes for training), and train/test features must agree.
assert train_x.shape[0] == train_y.shape[0], "train inputs/targets sample count mismatch"
assert train_x.shape[1] == train_y.shape[1], "train inputs/targets time-axis mismatch"
assert test_x.shape[0] == test_y.shape[0], "test inputs/labels sample count mismatch"
assert train_x.shape[-1] == test_x.shape[-1], "train/test feature size mismatch"
train_x.shape, train_y.shape, test_x.shape, test_y.shape

In [None]:
def get_batch_generator(is_test=True, batch_size=64):
    """Yield fixed-size minibatches from the module-level MNIST arrays.

    Only complete batches are produced; a trailing partial batch is skipped.

    Args:
        is_test: If True, yield ``(test_x_batch, test_y_batch)`` tuples. If
            False, yield ``(input_dict, output_dict)`` training pairs keyed by
            the names the NengoDL simulator expects.
        batch_size: Number of samples in every yielded batch.
    """
    source_x = test_x if is_test else train_x
    # Start offset of every complete batch (start + batch_size <= N).
    starts = range(0, source_x.shape[0] - batch_size + 1, batch_size)

    if is_test:
        for start in starts:
            stop = start + batch_size
            yield (test_x[start:stop], test_y[start:stop])
        return

    # NOTE(review): these keys are hard-coded node/probe labels of the converted
    # network and assume Keras named the layers input_1 / conv2d / conv2d_1 /
    # dense_2. The bias arrays are shaped (batch, channels, 1); confirm this
    # matches the converter's bias Nodes.
    for start in starts:
        stop = start + batch_size
        inputs = {
            "input_1": train_x[start:stop],
            "n_steps": np.ones((batch_size, 1)),
            "conv2d.0.bias": np.ones((batch_size, 32, 1)),
            "conv2d_1.0.bias": np.ones((batch_size, 64, 1)),
            "dense_2.0.bias": np.ones((batch_size, 10, 1)),
        }
        targets = {"probe": train_y[start:stop]}
        yield (inputs, targets)

# Train the model in NengoDL with (non-spiking) ReLU neurons

In [None]:
# Build the rate-based Keras model and convert it to a Nengo network, then list
# the Nodes and Ensembles the converter created. The Node labels are presumably
# the input keys the training batch generator has to supply.
model = get_model()[0]
converter = nengo_dl.Converter(model)
for nengo_objects in (converter.net.all_nodes, converter.net.all_ensembles):
    print(nengo_objects)

In [None]:
# Train the converted network with its default (rate-based) neurons and save the
# learned parameters so the spiking version below can load them.
TRAIN_BATCH_SIZE = 200  # generator batch size; must equal the Simulator minibatch_size
N_EPOCHS = 4

# One epoch = every complete batch the generator yields (60000 // 200 = 300).
train_steps = train_x.shape[0] // TRAIN_BATCH_SIZE

with nengo_dl.Simulator(converter.net, minibatch_size=TRAIN_BATCH_SIZE, seed=seed) as sim:
    sim.compile(
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
        # `learning_rate` replaces the deprecated `lr` alias, which newer TF rejects.
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    )
    for _ in range(N_EPOCHS):
        # The generator is single-pass, so build a fresh one for every epoch.
        batches = get_batch_generator(is_test=False, batch_size=TRAIN_BATCH_SIZE)
        sim.fit(batches, epochs=1, steps_per_epoch=train_steps)
    sim.save_params("./keras_to_snn_mnist")

# Test the model in NengoDL with spiking ReLU neurons

In [None]:
# Rebuild the same architecture and convert it with every ReLU replaced by a
# spiking ReLU. The conv layers use tf.nn.relu while the hidden Dense layers use
# "relu" (tf.keras.activations.relu). These are different function objects, so
# both must be keys; otherwise the conv layers stay rate-based.
spiking_relu = nengo.SpikingRectifiedLinear()
model, inp, otp = get_model()
ndl_model = nengo_dl.Converter(
    model,
    swap_activations={
        tf.nn.relu: spiking_relu,
        tf.keras.activations.relu: spiking_relu,
    },
    # Multiply neuron inputs / divide outputs so neurons fire faster while the
    # overall function is unchanged; helps the spiking approximation.
    scale_firing_rates=20,
    # Lowpass synapse (time constant in seconds) applied to spiking outputs.
    synapse=0.005,
)

In [None]:
# Evaluate the spiking network on the test set. Each image is presented for every
# time step in test_x and classified from the network output at the final step.
TEST_BATCH_SIZE = 100  # generator batch size; must equal the Simulator minibatch_size

ndl_mdl_input = ndl_model.inputs[inp]
ndl_mdl_output = ndl_model.outputs[otp]
with ndl_model.net:
    # Do not carry simulator state over between predict calls.
    nengo_dl.configure_settings(stateful=False)

n_correct = 0
n_seen = 0
with nengo_dl.Simulator(ndl_model.net, minibatch_size=TEST_BATCH_SIZE, seed=seed) as sim:
    sim.load_params("./keras_to_snn_mnist")
    for batch_x, batch_y in get_batch_generator(batch_size=TEST_BATCH_SIZE):
        data = sim.predict_on_batch({ndl_mdl_input: batch_x})
        # Output presumably shaped (batch, n_steps, 10); classify from the last step.
        predicted = np.argmax(data[ndl_mdl_output][:, -1], axis=-1)
        n_correct += int(np.sum(predicted == batch_y))
        n_seen += len(batch_y)

# Divide by the samples actually evaluated: the generator skips a trailing
# partial batch, so this can be fewer than test_y.shape[0].
acc = n_correct / n_seen if n_seen else float("nan")
print("ACC: %s" % acc)