Saving (trained) connections for later use, an overview

Hi all,

I thought it might be nice to have an overview of how to use the connections from a previously trained network to create one that functions the same. There are multiple posts and comments on the forum that touch upon this, but it is still a bit vague to me.


Decoded connection
Let's make a network that contains two ensembles called pre and post, both with n neurons and dimension d. They are connected using a decoded connection. (I believe there are no learning rules that adjust both encoders and decoders at the same time, but assuming so helps keep this post a bit shorter.)

pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, seed=seed)

# a decoded connection, the Nengo default
conn = nengo.Connection(pre, post, solver=nengo.solvers.LstsqL2(weights=False), seed=seed)

Decoders and post ensemble encoders are probed like this:

probe_dec = nengo.Probe(conn, "weights")
probe_enc = nengo.Probe(post, "scaled_encoders")

After doing whatever simulation and learning you want, you access the information inside the probes:

trained_decoders = sim.data[probe_dec][-1] # probe for the last time point/measurement
trained_scaled_encoders = sim.data[probe_enc][-1]
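Since the probed values are plain NumPy arrays, they can be written to disk and loaded again when building the new network. A minimal sketch (the filenames and the random stand-in arrays are just examples, not part of the original code):

```python
import numpy as np

# Stand-ins for sim.data[probe_dec][-1] and sim.data[probe_enc][-1]
trained_decoders = np.random.randn(2, 50)
trained_scaled_encoders = np.random.randn(50, 2)

# Save the trained arrays for later use
np.save("trained_decoders.npy", trained_decoders)
np.save("trained_scaled_encoders.npy", trained_scaled_encoders)

# Later, in the script that builds the new network:
loaded_decoders = np.load("trained_decoders.npy")
```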

We cannot, however, just plug these values into the new network that we want to create from the old network's connection. The scaled encoders need to be "unscaled" first; the decoders we can use as is. NOTE: IS THIS CORRECT?

trained_encoders = trained_scaled_encoders * post.radius / sim.model.params[post].gain[:, None]
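As a sanity check on that unscaling: Nengo combines gain, encoders, and radius into the stored "scaled encoders" as gain * encoders / radius, so the line above just inverts that relationship. A minimal NumPy sketch with made-up values (the n, d, radius, and random arrays are illustrative only):

```python
import numpy as np

n, d, radius = 50, 2, 1.5
gain = np.random.uniform(0.5, 2.0, size=n)
encoders = np.random.randn(n, d)
encoders /= np.linalg.norm(encoders, axis=1, keepdims=True)  # unit-length rows

# How the "scaled encoders" are formed from gain, encoders, and radius
scaled_encoders = gain[:, None] * encoders / radius

# The inversion used above recovers the original unit-length encoders
recovered = scaled_encoders * radius / gain[:, None]
```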

Now we are ready to construct a new network with the exact same properties as the one above.

# New model made to copy the model initially built
pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, encoders=trained_encoders, seed=seed)

# the trained decoders, applied as a transform from the pre neurons
conn = nengo.Connection(pre.neurons, post, transform=trained_decoders, seed=seed)


Direct ensemble-to-ensemble connection
Let's make a network that contains two ensembles called pre and post, both with n neurons and dimension d. They are connected using a direct connection, i.e., with full connection weights.

pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, seed=seed)

# a direct ensemble-to-ensemble connection
conn = nengo.Connection(pre, post, solver=nengo.solvers.LstsqL2(weights=True), seed=seed)

The connection weights are probed like this:

probe_weights = nengo.Probe(conn, "weights")  # same "weights" argument as for the decoder probe above, but we are probing something different!

After doing whatever simulation and learning you want, you access the information inside the probes:

trained_weights = sim.data[probe_weights][-1] # probe for the last time point/measurement

We cannot, however, just plug these values into the new network that we want to create from the old network's connection. The weights need to be divided by the post neurons' gains, as they get multiplied by them during the build of the new network. (NOTE: where in the Nengo code does this happen? I can't find it.)

trained_weights = trained_weights / sim.model.params[post].gain[:, None]
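The [:, None] indexing turns the length-n gain vector into an (n, 1) column, so NumPy broadcasts the division across each row of the (n_post, n_pre) weight matrix. A quick sketch with dummy shapes (the sizes and random arrays are illustrative only):

```python
import numpy as np

n_post, n_pre = 40, 50
gain = np.random.uniform(0.5, 2.0, size=n_post)
trained_weights = np.random.randn(n_post, n_pre)  # stand-in for the probed weights

# Divide every row i of the weight matrix by gain[i]
compensated = trained_weights / gain[:, None]

# Rebuilding the network multiplies the gains back in, recovering the originals
roundtrip = compensated * gain[:, None]
```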

Now we are ready to construct a new network with the exact same properties as the one above.

# New model made to copy the model initially built
pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, seed=seed)

# a direct neuron-to-neuron connection
conn = nengo.Connection(pre.neurons, post.neurons, transform=trained_weights, seed=seed)

Question

Is this correct? I am mostly unsure about what is needed to do after reading out the probes, so the “transformation” of the probed weights/encoders/decoders.

That is correct. Nengo combines the encoders, radius, and gains together to get the "scaled encoders". This is done to reduce the number of weights that have to be stored.

I tested your code and have several notes:

  • If you are not modifying the encoders of post, then simply setting the seed should be sufficient to create an ensemble that has the same parameters as the initial network.
  • If you are modifying the encoders of post, then in addition to setting the encoder values, you’ll also need to set the gain and bias values to match the values from the original network, like so:
post_new = nengo.Ensemble(
    n, d, encoders=trained_encoders, 
    gain=sim.data[post].gain, bias=sim.data[post].bias, 
    seed=seed)
  • Creating a connection from the pre.neurons object is technically correct; however, there is another method for creating a connection with the trained decoders that preserves the original style of connection from an ensemble. This alternative is to use the NoSolver solver and provide it with the decoder weights from above, like so:
conn = nengo.Connection(
    pre, post, solver=nengo.solvers.NoSolver(trained_decoders.T), seed=seed)
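On the transpose in that call: the probed decoders come out with shape (dimensions, n_neurons), while NoSolver expects an (n_neurons, dimensions) array. A tiny NumPy sketch of the shapes involved (the sizes are made up for illustration):

```python
import numpy as np

n, d = 50, 2
trained_decoders = np.random.randn(d, n)  # shape of sim.data[probe_dec][-1]

# NoSolver expects an (n_neurons, dimensions) array, hence the transpose
nosolver_values = trained_decoders.T
```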

This is correct. The gains are combined with the encoders here in the code, and this only applies to connections where the post object is a nengo.Ensemble. Note that for connections to neuron objects, the neuron gains are multiplied in here in the code (it's a different logic path in the connection builder code). These differences in how the gains are handled between ensemble connections and neuron connections do cause an issue when saving and loading weights (see this GitHub issue), so we are aware of this. I did attempt a fix, but there are other issues I'm not fully considering that are causing the tests to fail (and I do not currently have the time to address those other issues).

Note that the compensation for the gains (i.e., removing them from the weights) only needs to be done if your original connection is to an ensemble. If the original connection is to a neurons object, you do not have to compensate for the gains. E.g.,

# Original network
pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, seed=seed)

# a direct neuron-to-neuron connection
conn = nengo.Connection(pre.neurons, post.neurons, transform=np.random.random(...), seed=seed)
...
probe_weights = nengo.Probe(conn, "weights")
...
trained_weights = sim.data[probe_weights][-1]

Loaded network:

pre = nengo.Ensemble(n, d, seed=seed)
post = nengo.Ensemble(n, d, seed=seed)

# a direct neuron-to-neuron connection
conn = nengo.Connection(pre.neurons, post.neurons, transform=trained_weights, seed=seed)

Regarding loading the weights in a new network, the NoSolver method does not work with full (neuron-to-neuron) connections. Instead, you will need to do what you have done (i.e., create a connection between the neurons and manually set the weights with the transform parameter).