Nengo_dl identity function is rescaled for longer simulations

I’m encountering behaviour that I’m finding difficult to debug and understand: longer training simulations (i.e., more training data) of exactly the same network effectively rescale the output and degrade performance significantly.

The weird thing is that the loss reported by nengo_dl does not change by any significant amount. In fact, it seems uncorrelated with the MSE that I calculate offline from exactly the same data (see the print statements at the end of the post). The offline MSE mirrors what you would expect from visually inspecting the attached plot, while the loss reported by nengo_dl seems arbitrary.

Maybe I’m misunderstanding some nuances surrounding the optimization hyperparameters / how the MSE is reported / how the parameters are carried over? How can I get consistent performance across different lengths of training time?

Details: I’m trying to learn a function $\mathbb{R} \to \mathbb{R}$ by using backprop to optimize both the encoders and decoders, with a single layer of sigmoidal units in between (i.e., a standard perceptron). The example function is just the identity (i.e., a communication channel). Training and testing use the same 5 Hz sinusoid in all conditions. RectifiedLinear units produce approximately the same behaviour.

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import nengo
from nengo.utils.numpy import rmse

import nengo_dl
import tensorflow as tf


def go(sim_t,
       n_neurons=100,
       freq=5,
       n_epochs=100):

    with nengo.Network(seed=0) as inner:
        tf_input = nengo.Node(output=np.zeros(1))

        u = nengo.Node(size_in=1)
        x = nengo.Ensemble(n_neurons, 1, neuron_type=nengo.Sigmoid())
        y = nengo.Node(size_in=1)

        nengo.Connection(tf_input, u, synapse=None)
        nengo.Connection(u, x, synapse=None)
        nengo.Connection(x, y, synapse=None)

        tf_output = nengo.Probe(y, synapse=None)

    t = np.arange(0, sim_t, 0.001)
    data_y = np.sin(2*np.pi*freq*t)[:, None]
    data_u = data_y
    inputs = {tf_input: data_u[:, None, :]}
    outputs = {tf_output: data_y[:, None, :]}

    with nengo_dl.Simulator(inner, minibatch_size=100) as sim_train:
        optimizer = tf.train.AdamOptimizer()
        sim_train.train(inputs, outputs, optimizer, n_epochs=n_epochs, objective='mse')
        sim_train.freeze_params(inner)
        loss = sim_train.loss(inputs, outputs, 'mse')
        
    with nengo.Network() as outer:
        test_input = nengo.Node(output=nengo.processes.PresentInput(data_u, sim_train.dt))
        outer.add(inner)
        nengo.Connection(test_input, u, synapse=None)
        test_output = nengo.Probe(y, synapse=None)
        
    with nengo_dl.Simulator(outer) as sim:
        sim.run(sim_t)
        
    return {
        't': sim.trange(), 
        'actual': sim.data[test_output],
        'target': data_y,
        'loss': loss,
    }

try_sim_t = np.linspace(1, 16, 6)
data = []
for sim_t in try_sim_t:
    data.append(go(sim_t=sim_t))

sl = slice(1000)
plt.figure(figsize=(18, 4))
for sim_t, r in zip(try_sim_t, data):
    plt.plot(r['t'][sl], r['actual'][sl],
             label="sim_t=%s (loss=%.4f)" % (sim_t, r['loss']))
    print("sim_t=%s (mse=%s)" % (
        sim_t, rmse(r['target'].squeeze(), r['actual'].squeeze())**2))

# all of the targets are the same (only differs in length)
plt.plot(r['t'][sl], r['target'][sl], ls='--', lw=2, label="Target")

plt.legend()
plt.xlabel("Time (s)")
plt.show()

Build finished in 0:00:00
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:04 (loss: 0.0000)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:00                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:16 (loss: 0.0248)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:01                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:27 (loss: 0.0000)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:01                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:39 (loss: 0.0000)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:02                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:50 (loss: 0.0733)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:03                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:01:02 (loss: 0.0197)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:03                                                 
sim_t=1.0 (mse=0.00038221049835826233)
sim_t=4.0 (mse=0.03710108930405243)
sim_t=7.0 (mse=0.10872908536932545)
sim_t=10.0 (mse=0.5926093741152001)
sim_t=13.0 (mse=1.2567356972009183)
sim_t=16.0 (mse=2.116441248688516)

FYI, this example is a stripped-down version of something else that I’m working on. There, I found that a magic rescaling factor of approximately 0.13 significantly improves the accuracy of the trained network (when trained on 30 seconds of data).

Setting minibatch_size=len(t) (which I understand amounts to standard, full-batch gradient descent) results in the ideal behaviour:

An explanation of why would be nice. Is it over-fitting to minibatches whose input values are all close to each other? Can this behaviour be detected or prevented in some systematic way? Relatedly, minibatch_size currently defaults to 1 (which I understand is called “online” or fully stochastic gradient descent); what is the motivation for this default?

Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:01 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:00                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:02 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:01                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:04 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:01                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:06 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:02                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:07 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:03                                                 
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Training finished in 0:00:09 (loss: 0.0010)                                    
Build finished in 0:00:00                                                      
Optimization finished in 0:00:00                                               
Construction finished in 0:00:00                                               
Simulation finished in 0:00:03                                                 
sim_t=1.0 (mse=0.0009584487150118902)
sim_t=4.0 (mse=0.0009584402176110339)
sim_t=7.0 (mse=0.0009584383596686595)
sim_t=10.0 (mse=0.0009584407762897002)
sim_t=13.0 (mse=0.0009584469254131748)
sim_t=16.0 (mse=0.0009584209302731781)

A couple of different comments.

One general issue to keep in mind is that increasing sim_t does not increase the amount of training data in this case. Because your input is periodic, increasing the length of time over which you evaluate it just adds more identical training points. So increasing sim_t here is effectively similar to increasing n_epochs. Since this is a relatively simple problem, we wouldn’t really expect performance to improve much with more epochs (it might even get worse in this case, since we’re adding duplicates to the training data, which probably makes the minibatches more stochastic). I suspect that your performance is just fluctuating around some baseline, so increasing sim_t isn’t improving the performance (as reported by sim.loss).
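To make the point about duplicated training points concrete, here is a small sketch in plain NumPy (not nengo_dl) that rebuilds the sinusoid from the question for each sim_t and counts the distinct input values and the number of optimizer updates. The helper name training_stats is hypothetical, and the update count assumes one optimizer step per full minibatch per epoch, with minibatch_size=100 and n_epochs=100 as in go().

```python
import numpy as np


def training_stats(sim_t, freq=5, dt=0.001, minibatch_size=100, n_epochs=100):
    """Hypothetical helper: count samples, distinct inputs and optimizer
    updates for the sinusoidal training set built in go()."""
    n_samples = int(round(sim_t / dt))
    t = np.arange(n_samples) * dt
    u = np.sin(2 * np.pi * freq * t)
    # Round to 6 decimals so that values repeated across periods, which differ
    # only by floating-point error, are counted once.
    n_distinct = np.unique(np.round(u, 6)).size
    # Assumption: one optimizer update per full minibatch, per epoch.
    n_updates = (n_samples // minibatch_size) * n_epochs
    return n_samples, n_distinct, n_updates


for sim_t in np.linspace(1, 16, 6):
    n_samples, n_distinct, n_updates = training_stats(sim_t)
    print("sim_t=%.1f: %d samples, %d distinct inputs, %d updates"
          % (sim_t, n_samples, n_distinct, n_updates))
```

Under those assumptions the number of distinct inputs is the same for every sim_t, while the number of updates grows linearly with it, which is why a longer sim_t behaves more like extra epochs than extra data.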

Another issue is why the loss values reported by sim.loss don’t seem to be matching up with the plots. This has to do with how you are running things with the outer network. Importantly, there is a learned parameter on the tf_input->u connection (the transform). When you are running sim_train (and computing the loss there), you are running the network including that parameter. When you are running sim (generating the plot data), you are connecting from test_input->u, and bypassing that parameter. So the output you are seeing from sim, and plotting, is not the same output as you are getting when running sim_train, and computing the loss.
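The mismatch can be reproduced without nengo_dl in a toy linear model, assuming the learned tf_input->u transform acts as a single scalar gain g in front of trainable downstream weights, collapsed here into one scalar w. Gradient descent only constrains the product w * g, so g drifts away from its initial value of 1; evaluating the trained w with g bypassed (as the test_input->u connection does) then gives a rescaled output, even though the training loss is essentially zero. The initial w, learning rate and step count are illustrative only.

```python
import numpy as np

# Toy stand-in for the network (not nengo_dl): y = w * (g * u), where g plays
# the role of the learned tf_input->u transform and w the downstream weights.
t = np.arange(1000) * 0.001
u = np.sin(2 * np.pi * 5 * t)  # 5 Hz input, as in the question
target = u                     # identity function

g = 1.0   # starts at the default transform value
w = 0.2   # illustrative initial downstream weight
lr = 0.1  # illustrative learning rate
for _ in range(500):
    err = w * g * u - target
    # Gradients of the mean squared error, both computed before either update
    grad_w = np.mean(2 * err * g * u)
    grad_g = np.mean(2 * err * w * u)
    w -= lr * grad_w
    g -= lr * grad_g

train_mse = np.mean((w * g * u - target) ** 2)  # output with g, as in training
bypass_mse = np.mean((w * u - target) ** 2)     # g skipped, as via test_input->u
print("g=%.3f w=%.3f train_mse=%.2e bypass_mse=%.4f" % (g, w, train_mse, bypass_mse))
```

Multiplying by the learned g again (the equivalent of reusing the trained transform on the test connection) gives back train_mse, because the output is then computed exactly as it was during training.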

When you set minibatch_size=len(t), you’re running far fewer training steps. I suspect that the parameters of the AdamOptimizer are such that this doesn’t do much to the transform, so you’re basically seeing the performance of the initial parameters. The tf_input->u transform isn’t changing much from its initial value, so bypassing it as outlined above doesn’t really change the behaviour of the network. Setting minibatch_size=len(t) also means that all the simulations run for the same number of training steps. And since they all have the same training data, and we’re eliminating the stochasticity of minibatching, they end up with exactly the same parameters (which is why they all match).
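A rough way to see why full-batch training barely moves the transform is to compare the number of steps against Adam's step size. The sketch below is the textbook Adam update (Kingma and Ba) under an idealized constant gradient, using the same default hyperparameters as tf.train.AdamOptimizer (learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08). TensorFlow arranges the bias correction slightly differently and real gradients are not constant, so treat the result as an order-of-magnitude estimate: each step moves a parameter by about learning_rate, so 100 full-batch updates (one per epoch) can shift it by roughly 0.1, while minibatch_size=100 with sim_t=16 gives 16000 updates and room for a much larger drift.

```python
import math


def adam_displacement(n_steps, grad=1.0, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Total parameter change after n_steps of textbook Adam, assuming the
    gradient stays constant (an idealization)."""
    theta, m, v = 0.0, 0.0, 0.0
    for step in range(1, n_steps + 1):
        m = beta1 * m + (1 - beta1) * grad       # first-moment estimate
        v = beta2 * v + (1 - beta2) * grad ** 2  # second-moment estimate
        m_hat = m / (1 - beta1 ** step)          # bias corrections
        v_hat = v / (1 - beta2 ** step)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return abs(theta)


print("full batch (100 updates):    %.4f" % adam_displacement(100))
print("minibatch, sim_t=16 (16000): %.4f" % adam_displacement(16000))
```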

The default value of minibatch_size is 1 because that is the behaviour in core Nengo (i.e. if you call sim.run(1.0), you are just simulating a single pass through the network). We want nengo_dl to be a drop-in replacement for the core nengo Simulator, so that behaviour should be unchanged. We could, in theory, have the default minibatch size be larger, and then obscure that from the user by throwing away any “extra” data. But that would slow down the simulation in the default case, in a way that is pretty obscure to the user, which doesn’t seem great.


Ah, this was indeed the issue. It hadn’t occurred to me that backprop would learn the transform on a passthrough node. Usually this is a free parameter that can be rolled into the downstream encoders/gains/transforms. But I should have realized. I have changed the line from:

        nengo.Connection(tf_input, u, synapse=None)

to:

        bootstrap = nengo.Connection(tf_input, u, synapse=None)
        nengo_dl.configure_settings(trainable=None)
        inner.config[bootstrap].trainable = False

and we now get consistent performance across each data set. Hurray!
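A toy linear stand-in (not nengo_dl) illustrates why freezing the bootstrap connection fixes things. Assume the network reduces to y = w * g * u, where the scalar g plays the role of the tf_input->u transform and w the downstream weights; all numbers are illustrative. With g frozen at its default value of 1, which is what trainable=False does for that connection, only w can absorb the scaling, so training drives w to 1 and bypassing g at test time no longer changes the output.

```python
import numpy as np

# Toy stand-in (not nengo_dl): y = w * (g * u). g mimics the tf_input->u
# transform with trainable=False, so it stays at its default value of 1.
t = np.arange(1000) * 0.001
u = np.sin(2 * np.pi * 5 * t)
target = u

g = 1.0   # frozen: never updated below
w = 0.2   # illustrative initial downstream weight
lr = 0.1  # illustrative learning rate
for _ in range(500):
    err = w * g * u - target
    w -= lr * np.mean(2 * err * g * u)  # only w is trained

bypass_mse = np.mean((w * u - target) ** 2)  # g skipped, as via test_input->u
print("w=%.6f bypass_mse=%.2e" % (w, bypass_mse))
```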

The reason I’m doing things with these bootstrapping connections is because I have two sub-networks: one trained using nengo_dl, and another trained in some other fashion. For training I need to simulate them individually. Then their inputs and outputs are mutually coupled for inference. I couldn’t figure out a more elegant way to do this, since it seems all input nodes must have size_in==0.

Point taken, but I also saw these exact same symptoms when the training data was aperiodic. The periodicity was just me trying to make a bare-bones example, so this behaviour can happen in more general circumstances. It makes complete sense given the explanation above that I was bypassing a learned gain on the input transform. Thanks!

Another option would be to do

nengo.Connection(test_input, u, synapse=None, transform=bootstrap.transform)

in case you wanted to continue allowing that parameter to be learned (after sim_train.freeze_params(inner), bootstrap.transform holds the trained value). Either way works, though!