Spiking LMU: Accuracy vs presentation time

Hello Nengo Community,

I am using a spiking LMU for a classification task. The classification accuracy is around 85% when the input is presented for a single time step (noSteps = 1). However, when I increase the number of time steps to 20 (noSteps = 20), the accuracy drops, whereas I would expect the opposite: in theory, presenting the input for more steps allows spikes to be integrated over a longer time period, which should increase accuracy. Could someone please look at the model and let me know why this is happening? The model is given below. Thank you very much in advance.

# Repetition factor applied along the time axis of the test data when the
# final evaluation is run.
noSteps = 1

# Fix every source of randomness used by the script (a dedicated NumPy
# generator, the global NumPy RNG and TensorFlow) so runs are reproducible.
seed = 0
rng = np.random.RandomState(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

def _lmu_state_space(order, theta, dt=1.0):
    """Return the zero-order-hold discretized LMU memory matrices ``(A, B)``.

    Builds the continuous-time Legendre Memory Unit system for ``order``
    Legendre coefficients over a window of length ``theta`` and discretizes
    it with step ``dt``.

    Parameters
    ----------
    order : int
        Number of Legendre coefficients (dimensionality of the memory).
    theta : float
        Length of the memory window, in the same time units as ``dt``.
    dt : float
        Discretization step; ``1.0`` means one simulator step per unit of
        ``theta``.

    Returns
    -------
    A : ndarray, shape (order, order)
    B : ndarray, shape (order, 1)
    """
    q = np.arange(order, dtype=np.float64)
    r = (2 * q + 1)[:, None] / theta
    j, i = np.meshgrid(q, q)

    a = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * r
    b = (-1.0) ** q[:, None] * r
    c = np.ones((1, order))
    d = np.zeros((1,))

    # zero-order hold is exact for inputs that stay constant over each dt
    a_d, b_d, _, _, _ = cont2discrete((a, b, c, d), dt=dt, method="zoh")
    return a_d, b_d


class LMUCell(nengo.Network):
    """Legendre Memory Unit cell built from Nengo objects.

    Internal objects (per simulator step t):

    * ``x`` -- input node with ``input_d`` dimensions.
    * ``u`` -- scalar node, the sum of the components of ``x``.
    * ``m`` -- ``order``-dimensional memory node,
      ``m_t = A m_{t-1} + B u_t`` (A is fixed, B is fixed).
    * ``h`` -- ``units`` Tanh neurons driven by ``x_t``, ``m_t`` and
      ``h_{t-1}`` through trainable Glorot-initialized weights.

    Parameters
    ----------
    units : int
        Number of hidden neurons in ``h``.
    order : int
        Number of Legendre coefficients held in ``m``.
    theta : float
        Length of the memory window in units of ``dt`` (with the default
        ``dt=1.0``, in simulator steps).
    input_d : int
        Dimensionality of the input ``x``.
    dt : float, optional
        Discretization step for the A/B matrices. The default ``1.0``
        reproduces one simulator step per unit of ``theta``. If every input
        sample is held for ``k`` simulator steps, ``dt=1.0 / k`` keeps the
        memory window at ``theta`` samples. Note that the ``h -> h``
        recurrence acts once per simulator step and is NOT rescaled by ``dt``.
    **kwargs
        Forwarded to ``nengo.Network``.

    NOTE(review): ``nengo.Tanh`` is a rate-based (non-spiking) neuron type,
    so this cell as written does not produce spikes.
    """

    def __init__(self, units, order, theta, input_d, dt=1.0, **kwargs):
        super().__init__(**kwargs)

        A, B = _lmu_state_space(order, theta, dt)

        with self:
            nengo_dl.configure_settings(trainable=None)

            # objects corresponding to the x/u/m/h variables of the LMU
            self.x = nengo.Node(size_in=input_d)
            self.u = nengo.Node(size_in=1)
            self.m = nengo.Node(size_in=order)
            # gain=1 and bias=0 so the neurons' input current is exactly the
            # weighted sum of the incoming connections
            self.h = nengo.Ensemble(
                units,
                1,
                neuron_type=nengo.Tanh(tau_ref=1),
                gain=np.ones(units),
                bias=np.zeros(units),
            ).neurons

            # u_t = sum(x_t); the e_h / e_m encoders are not needed here
            nengo.Connection(
                self.x,
                self.u,
                transform=np.ones((1, input_d)),
                synapse=None,
            )

            # m_t = A m_{t-1} + B u_t. A and B are kept non-trainable.
            # synapse=0 (versus synapse=None) adds a one-timestep delay, so a
            # connection with synapse=0 represents value_{t-1}.
            conn_A = nengo.Connection(
                self.m,
                self.m,
                transform=A,
                synapse=0,
            )
            self.config[conn_A].trainable = False

            conn_B = nengo.Connection(
                self.u,
                self.m,
                transform=B,
                synapse=None,
            )
            self.config[conn_B].trainable = False

            # h_t is driven by x_t, h_{t-1} (delayed via synapse=0) and m_t
            nengo.Connection(
                self.x,
                self.h,
                transform=nengo_dl.dists.Glorot(),
                synapse=None,
            )
            nengo.Connection(
                self.h,
                self.h,
                transform=nengo_dl.dists.Glorot(),
                synapse=0,
            )
            nengo.Connection(
                self.m,
                self.h,
                transform=nengo_dl.dists.Glorot(),
                synapse=None,
            )

with nengo.Network(seed=seed) as net:
    # Remove some unnecessary features to speed up training.
    # keep_history=True makes probes record every timestep (not only the
    # last), which is why the accuracy metric later slices [:, -1].
    nengo_dl.configure_settings(
        trainable=None,
        stateful=False,
        keep_history=True,
    )

    # Input node; its placeholder value is replaced by the data passed to
    # sim.fit / sim.evaluate.
    inp = nengo.Node(np.zeros(x_train.shape[-1]))

    # LMU cell whose memory window spans the training sequence length
    lmu = LMUCell(
        units=64,
        order=256,
        theta=x_train.shape[1],
        input_d=x_train.shape[-1],
    )
    conn = nengo.Connection(inp, lmu.x, synapse=None)
    net.config[conn].trainable = False

    # dense linear readout to the 15 output classes
    out = nengo.Node(size_in=15)
    nengo.Connection(lmu.h, out, transform=nengo_dl.dists.Glorot(), synapse=None)

    # Unlabeled probe; presumably the one the "probe_accuracy" metric names
    # used during training refer to -- confirm.
    p = nengo.Probe(out)

    # NOTE(review): despite its name this probe applies no synapse (same as
    # `p`), so its value at the last step reflects only that step's output;
    # no temporal filtering/integration happens at the readout.
    out_p_filt = nengo.Probe(out, label="out_p_filt")

do_training = True
from tensorflow.keras.callbacks import Callback


class TrainingCallback(Callback):
    """Save the simulator parameters whenever validation accuracy improves.

    Uses the module-level ``sim`` (the simulator being trained), which must
    exist by the time the callback fires. Weights go to ``./LMU_Weights``.
    """

    def __init__(self):
        # Initialise the Keras base class (it sets up attributes such as
        # ``model`` that Keras relies on).
        super().__init__()
        # Best validation accuracy seen so far; starting at 0 means the first
        # epoch with non-zero accuracy triggers a save.
        self.Best_validation_Acc = 0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs.get("val_probe_accuracy")
        # No validation metric this epoch (e.g. validation skipped): nothing
        # to compare against.
        if acc is None:
            return
        if acc > self.Best_validation_Acc:
            self.Best_validation_Acc = acc
            sim.save_params("./LMU_Weights")

def classification_accuracy(y_true, y_pred):
    """Sparse categorical accuracy of the prediction at the final timestep."""
    return tf.metrics.sparse_categorical_accuracy(y_true[:, -1], y_pred[:, -1])


with nengo_dl.Simulator(net, minibatch_size=32) as sim:
    sim.compile(
        loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.optimizers.Adam(),
        metrics=["accuracy"],
    )

    print(
        "Accuracy before training:",
        sim.evaluate(x_test, y_test, verbose=0)["probe_accuracy"],
    )

    if do_training:
        sim.fit(
            x_train,
            y_train,
            verbose=1,
            epochs=2000,
            validation_data=(x_test, y_test),
            callbacks=[TrainingCallback()],
        )

    # Load the best checkpoint written by TrainingCallback during this run,
    # or weights saved by a previous run when training is skipped.
    sim.load_params("./LMU_Weights")

    # Present every input time step for noSteps consecutive simulator steps:
    # np.repeat holds each step (a0 a0 a1 a1 ...), whereas np.tile would
    # replay the whole sequence (a0 a1 a0 a1 ...), a different input. The two
    # coincide when the sequences have a single time step.
    # NOTE(review): assumes x_test / y_test are (batch, steps, dims) -- confirm.
    #
    # NOTE(review): a longer presentation does not by itself integrate more
    # evidence in this model. The LMU memory (A/B discretized with dt=1.0),
    # the h -> h recurrence (synapse=0) and the unfiltered out_p_filt probe all
    # advance once per simulator step. Evaluating on sequences noSteps times
    # longer than the training sequences therefore drives the network through
    # dynamics it was not trained on. To benefit from longer presentation,
    # train with the same noSteps you evaluate with (and/or rescale the LMU
    # discretization and filter the readout).
    x_test_tiled = np.repeat(x_test, noSteps, axis=1)
    y_test_tiled = np.repeat(y_test, noSteps, axis=1)

    # Re-compile with the accuracy function as the "loss" so that evaluate()
    # reports last-timestep accuracy under the "loss" key.
    sim.compile(loss={out_p_filt: classification_accuracy})

    acc = sim.evaluate(x_test_tiled, {out_p_filt: y_test_tiled}, verbose=0)["loss"]
    print("Accuracy:", acc)