Thanks Ben,
I did think it could be something to do with the compile step, since that's the only bit that really changes; my code is below. I might try a few different loss functions and optimisers in the meantime.
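For instance, something like this (just a sketch, untested; `sim` and `out_p` are defined in the code below, and since my targets are images rather than class labels I suspect the crossentropy losses don't really apply here):

    sim.compile(
        optimizer=tf.optimizers.RMSprop(learning_rate=0.001),
        loss={out_p: tf.losses.Huber()},  # or tf.losses.MeanAbsoluteError()
        metrics=["mse"],
    )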
import datetime

import numpy as np
import tensorflow as tf

import nengo
import nengo_dl

# (train_histograms, train_images, test_histograms and test_images are
# loaded earlier, outside this snippet)

with nengo.Network(seed=0) as net:
    # copied from the example
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None
    neuron_type = nengo.LIF(amplitude=0.01)

    # from the example
    nengo_dl.configure_settings(stateful=False)

    # the input node that will be used to feed in the input histograms
    inp = nengo.Node(np.zeros(7999))

    # I've tried it this way and also without the neuron_type layers
    hidden = nengo_dl.Layer(
        tf.keras.layers.Dense(units=1024, activation=tf.nn.relu))(inp)
    hidden = nengo_dl.Layer(neuron_type)(hidden)

    hidden = nengo_dl.Layer(
        tf.keras.layers.Dense(units=512, activation=tf.nn.relu))(hidden)
    hidden = nengo_dl.Layer(neuron_type)(hidden)

    hidden = nengo_dl.Layer(
        tf.keras.layers.Dense(units=256, activation=tf.nn.relu))(hidden)
    hidden = nengo_dl.Layer(neuron_type)(hidden)

    out = nengo_dl.Layer(
        tf.keras.layers.Dense(units=4096))(hidden)

    # the example creates two output probes: one with a filter (for
    # simulating the network over time and accumulating spikes) and one
    # without (for training the network with a rate-based approximation);
    # I've commented out the filtered one for now
    out_p = nengo.Probe(out, label="out_p")
    # out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")
minibatch_size = 200
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)
# add a single timestep dimension to the training data
train_histograms = train_histograms[:, None, :]
train_images = train_images[:, None, :]
# when testing our network with spiking neurons we will need to run it
# over time, so we repeat the input/target data for a number of
# timesteps.
n_steps = 300
test_images = np.tile(test_images[:, None, :], (1, n_steps, 1))
test_histograms = np.tile(test_histograms[:, None, :], (1, n_steps, 1))
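# sanity check on shapes (my own note): test_histograms should now be
# (n_test, n_steps, 7999) and test_images (n_test, n_steps, 4096),
# matching nengo_dl's (batch, time, features) layout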
# def classification_accuracy(y_true, y_pred):
#     # tf.keras.metrics.Accuracy is a class, so calling it like a function
#     # doesn't work; the example's functional form would be:
#     return tf.metrics.sparse_categorical_accuracy(
#         y_true[:, -1], y_pred[:, -1])
#
# note that the example uses `out_p_filt` when testing (to reduce the
# spike noise)
# sim.compile(loss={out_p_filt: classification_accuracy})
# sim.compile(optimizer=tf.optimizers.Adam(),
#             loss='mse',
#             # other losses to try:
#             # tf.losses.SparseCategoricalCrossentropy(from_logits=True)
#             # tf.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0)
#             metrics=["accuracy"])
# print("loss before training:",
#       sim.evaluate(test_histograms, {out_p: test_images}, verbose=0)["loss"])
do_training = True
if do_training:
    # run training
    sim.compile(optimizer=tf.optimizers.Adam(),
                loss='mse',
                # other losses to try:
                # tf.losses.SparseCategoricalCrossentropy(from_logits=True)
                # tf.losses.CategoricalCrossentropy(from_logits=False, label_smoothing=0)
                metrics=["accuracy"])
    # sim.compile(
    #     optimizer=tf.optimizers.RMSprop(0.001),
    #     loss={out_p: tf.keras.losses.MeanSquaredError()},
    # )
    sim.fit(train_histograms, {out_p: train_images}, epochs=1000)

    # save the parameters to file
    sim.save_params("./model_allphotons_{}".format(datetime.datetime.today()))
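    # (aside: str(datetime.datetime.today()) puts spaces and colons in the
    # filename, which some filesystems dislike; strftime("%Y-%m-%d_%H-%M-%S")
    # might be safer)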
else:
    # download pretrained weights (left over from the example)
    # urlretrieve(
    #     "https://drive.google.com/uc?export=download&"
    #     "id=1l5aivQljFoXzPP5JVccdFXbOYRv3BCJR",
    #     "mnist_params.npz")

    # load parameters
    sim.load_params("./model_allphotons_2020-03-17 12:24:15.238096")
    # sim.load_params("./model_9500photons_2020-03-17 00:52:19.731343")

# sim.compile(loss={out_p_filt: classification_accuracy})
# print("loss after training:",
#       sim.evaluate(test_histograms, {out_p: test_images}, verbose=0)["loss"])
data = sim.predict(test_histograms[:minibatch_size])
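For what it's worth, this is roughly how I plan to check the spiking output from predict (my own sketch: since `out_p` is unfiltered, I average the last few timesteps as a stand-in for the example's filtered `out_p_filt` probe):

    # average the last 10 timesteps of the output and compare to the targets
    pred = data[out_p][:, -10:].mean(axis=1)
    mse = np.mean((pred - test_images[:minibatch_size, -1]) ** 2)
    print("test MSE (spiking):", mse)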