To use a data generator with sim.fit, the generator needs to yield data for the input, the output, n_steps, and all of the bias nodes. The data for each bias node can be an array of ones matching the shape of that bias, and the n_steps entry should have shape (batchsize, 1).
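The exact input names (including which bias nodes exist and what they are called) depend on the model, and they are the keys used in the data dicts. One way to look them up (just a sketch, not part of the original example; it assumes the model and converter defined in the code below) is to print the labels of the nodes in the converted network:

converter = nengo_dl.Converter(model)
for node in converter.net.all_nodes:
    print(node.label)  # bias nodes show up with labels like "conv2d.0.bias"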
Here is an example of sim.fit with a data generator:
from urllib.request import urlretrieve
import matplotlib.pyplot as plt
import nengo
import numpy as np
import tensorflow as tf
import nengo_dl
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.mnist.load_data())
# flatten images and add time dimension
train_images = train_images.reshape((train_images.shape[0], 1, -1))
train_labels = train_labels.reshape((train_labels.shape[0], 1, -1))
test_images = test_images.reshape((test_images.shape[0], 1, -1))
test_labels = test_labels.reshape((test_labels.shape[0], 1, -1))
# n_steps, scale_firing_rate and synapse are typically used when running the
# converted network in spiking mode; they are not used by the training code below
n_steps = 200
scale_firing_rate = 100
synapse = 0.005
batchsize = 128
inp = tf.keras.Input(shape=(28, 28, 1))
# convolutional layers
conv0 = tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=3,
    activation=tf.nn.relu,
)(inp)
conv1 = tf.keras.layers.Conv2D(
    filters=64,
    kernel_size=3,
    strides=2,
    activation=tf.nn.relu,
)(conv0)
# fully connected layer
flatten = tf.keras.layers.Flatten()(conv1)
dense = tf.keras.layers.Dense(units=10)(flatten)
model = tf.keras.Model(inputs=inp, outputs=dense)
model.summary()
converter = nengo_dl.Converter(model)
# generator yielding (inputs, targets) pairs, keyed by the simulator's
# input and probe names
def get_batches():
    for i in range(0, 10000, batchsize):
        ip = train_images[i:i + batchsize]
        label = train_labels[i:i + batchsize]
        yield (
            {
                "input_1": ip,
                "n_steps": np.ones((batchsize, 1), dtype=np.int32),
                "conv2d.0.bias": np.ones((batchsize, 32, 1), dtype=np.int32),
                "conv2d_1.0.bias": np.ones((batchsize, 64, 1), dtype=np.int32),
                "dense.0.bias": np.ones((batchsize, 10, 1), dtype=np.int32),
            },
            {"probe": label},
        )

data_generator = get_batches()
with nengo_dl.Simulator(converter.net, minibatch_size=batchsize) as sim:
    # run training
    sim.compile(
        optimizer=tf.optimizers.RMSprop(0.001),
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.metrics.sparse_categorical_accuracy],
    )
    sim.fit(
        data_generator,
        epochs=2,
        steps_per_epoch=10,
    )
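Evaluating on the test set works the same way: the test data just has to be packed into the same dict format. Here is a rough sketch (not from the original code; it assumes the evaluate call is placed inside the same with block, after sim.fit, so the trained weights are still in place):

def get_test_batches():
    # same input/output format as the training generator
    for i in range(0, 9984, batchsize):
        yield (
            {
                "input_1": test_images[i:i + batchsize],
                "n_steps": np.ones((batchsize, 1), dtype=np.int32),
                "conv2d.0.bias": np.ones((batchsize, 32, 1), dtype=np.int32),
                "conv2d_1.0.bias": np.ones((batchsize, 64, 1), dtype=np.int32),
                "dense.0.bias": np.ones((batchsize, 10, 1), dtype=np.int32),
            },
            {"probe": test_labels[i:i + batchsize]},
        )

# inside the same with nengo_dl.Simulator(...) block, after sim.fit:
print(sim.evaluate(get_test_batches()))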