Can't learn dynamics using nengo.Ensemble

Hello,

I’m trying to learn a dynamical system whose value ramps up to a limit and then resets to 0, over and over. This should be a continuous process. I tried using more neurons and tuning the synapse and radius, but it doesn’t work. Am I missing something?

Here is a snippet that reproduces the behavior:

import nengo
import numpy as np
import matplotlib.pyplot as plt

tau = 0.01

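# Desired discrete update: step x up by 0.0005, reset to 0 once it exceeds 1
def feedback(x):
    x = x + 0.0005
    if x > 1:
        return 0
    return x

model = nengo.Network(seed=42)
with model:
    state = nengo.Ensemble(2000, 1, radius=np.sqrt(2))
    # Recurrent connection computing feedback(x) through a lowpass synapse
    nengo.Connection(state, state, function=feedback, synapse=tau)
    state_probe = nengo.Probe(state, synapse=tau)

with nengo.Simulator(model) as sim:
    sim.run(5)

# Reference (red) trajectory: apply feedback() once per probe sample
# (one sample per time step), starting from the first probed value
samples = len(sim.data[state_probe])
x = sim.data[state_probe][0]
points = [x]
for i in range(samples-1):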
    x = feedback(x)
    points.append(x)

points = np.array(points)

plt.plot(range(samples), sim.data[state_probe], label="x1", alpha=0.8)
plt.plot(range(samples), points, color="r")
plt.show()

Resulting plot:


Basically, I want the blue line to match the red line.

Thank you for any input

Hi @Crol, and welcome to the Nengo forums!

Unfortunately, the function you are trying to reproduce is a little difficult to actually implement in a single neural ensemble. There are a few ways to implement the sawtooth function you want in Nengo.

The first approach requires some familiarity with the NEF (Neural Engineering Framework). If you can define the sawtooth function as a dynamical system, you can use Principle 3 of the NEF to convert the dynamical system’s equation into a neural implementation.
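A minimal sketch of the Principle 3 idea, under simplifying assumptions (a noise-free state, a first-order lowpass synapse with time constant tau, and forward-Euler integration instead of Nengo’s own discretization): a recurrent connection computing F(x) gives tau * dx/dt = F(x) - x, so choosing F(x) = x + tau * g(x) yields dx/dt = g(x). The helper names `simulate_recurrent` and `principle3_feedback` are hypothetical, and the code does not use Nengo itself.

```python
def simulate_recurrent(feedback, tau, dt=0.001, t_end=1.0, x0=0.0):
    """Forward-Euler simulation of tau * dx/dt = feedback(x) - x.

    This is the idealized (noise-free) dynamics of a population whose
    recurrent connection computes `feedback` through a lowpass synapse.
    """
    x = x0
    for _ in range(int(round(t_end / dt))):
        x += (dt / tau) * (feedback(x) - x)
    return x


def principle3_feedback(g, tau):
    """Feedback function F(x) = x + tau * g(x), which gives dx/dt = g(x)."""
    return lambda x: x + tau * g(x)


tau = 0.1
# A constant derivative g(x) = 1 should produce a ramp of slope 1 per second.
ramp_end = simulate_recurrent(principle3_feedback(lambda x: 1.0, tau), tau)

# The feedback from the original question, x + 0.0005 with tau = 0.01,
# corresponds to dx/dt = 0.0005 / 0.01 = 0.05 per second.
question_end = simulate_recurrent(lambda x: x + 0.0005, tau=0.01)
```

Under these assumptions, the feedback function from the original question asks the ensemble to ramp at about 0.05 per second, whereas the red reference line, which applies `feedback` once per 1 ms time step (Nengo’s default `dt`), ramps at 0.5 per second.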

An alternative approach is to break the sawtooth function down into its component parts. The ramp part of the sawtooth is basically the output of a neural integrator with a constant input (integrating a constant input gives you a ramp), so you can use a neural integrator to produce the ramp. Next, you need some way to detect when the ramp should reset, which requires another neural ensemble that detects when the output of the neural integrator has reached a certain value. Finally, you need a way to reset the output of the neural integrator back to 0 when the detection ensemble is triggered. That is another neural ensemble whose implementation you will need to figure out.
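The three-part decomposition can be sketched as an idealized, non-neural loop in plain Python. This assumes a perfect integrator, an instantaneous threshold detector, and a reset that sets the integrator to exactly 0 in a single step; in the neural version each piece would be its own ensemble or subnetwork, and none of them is instantaneous. The function name and its parameters are hypothetical.

```python
def simulate_sawtooth(rate=1.0, threshold=1.0, dt=0.001, t_end=3.5):
    """Idealized integrator + threshold detector + reset loop."""
    x = 0.0
    trace = []
    n_resets = 0
    for _ in range(int(round(t_end / dt))):
        x += rate * dt              # integrator: constant input -> ramp
        detected = x >= threshold   # threshold detector
        if detected:                # reset signal drives the integrator to 0
            x = 0.0
            n_resets += 1
        trace.append(x)
    return trace, n_resets


trace, n_resets = simulate_sawtooth()
```

Because the reset here is instantaneous, there is no pause between cycles; a neural reset needs time to act and then switch off, which is one source of the pauses between cycles.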

While it is not exactly what you want, I describe a network that does something similar to what you desire in my PhD thesis (Figure 4.6 & 4.7). The output of my circuit is a sawtooth, but it has slight pauses between the end of one sawtooth and the start of the next:

The code for this network can be found embedded in a larger network here. I leave it up to you to explore my code, but feel free to ask me any questions about it. I’d advise you to start by implementing the base networks first (integrator, threshold detector, pulse generator, etc.) and making sure that those work. Once the individual parts work, you can then start putting them together into a larger network.

Thank you @xchoo

Very good instructions and insights.
I was able to use make_thresh_ens_net to detect the value at which to reset,
but the signal from this network is very weak if you reset the source of its signal. As I understand it, that’s why you amplify this value and then use it to inhibit the ramp.
My question is: how do I make it work for a new cycle?

import nengo
from nengo.dists import Uniform, Choice, Exponential
import numpy as np
import matplotlib.pyplot as plt


def make_thresh_ens_net(threshold=0.5, thresh_func=lambda x: 1,
                        exp_scale=None, num_ens=1, net=None, **args):
    if net is None:
        label_str = args.get('label', 'Threshold_Ens_Net')
        net = nengo.Network(label=label_str)
    if exp_scale is None:
        exp_scale = (1 - threshold) / 10.0

    with net:
        ens_args = dict(args)
        ens_args['n_neurons'] = 200
        ens_args['dimensions'] = 1
        ens_args['intercepts'] = \
            Exponential(scale=exp_scale, shift=threshold,
                        high=1)
        ens_args['encoders'] = Choice([[1]])
        ens_args['eval_points'] = Uniform(min(threshold + 0.1, 1.0), 1.1)
        ens_args['n_eval_points'] = 5000

        net.input = nengo.Node(size_in=num_ens)
        net.output = nengo.Node(size_in=num_ens)

        for i in range(num_ens):
            thresh_ens = nengo.Ensemble(**ens_args)
            nengo.Connection(net.input[i], thresh_ens, synapse=None)
            nengo.Connection(thresh_ens, net.output[i],
                             function=thresh_func, synapse=None)
    return net

tau = 0.1

radius = np.sqrt(3)

# Feedback x + tau * 1 makes the integrator ramp at about 1 unit per second
def grouth_equesions(x):
    return x + 1 * tau

model = nengo.Network(seed=42)
with model:
    # Ramp integrator
    state = nengo.Ensemble(1000, 1, radius=radius)
    nengo.Connection(state, state, function=grouth_equesions, synapse=tau)

    # Reset signal generator: a threshold network with a recurrent connection
    ramp_reset = make_thresh_ens_net(0.07, thresh_func=lambda x: x)
    nengo.Connection(ramp_reset.output,
                     ramp_reset.input)

    # Threshold detector: active when the ramp value exceeds 1
    thresh = make_thresh_ens_net(0, radius=1.1)

    nengo.Connection(thresh.output, ramp_reset.input,
                     transform=5.0, synapse=0.015)

    nengo.Connection(state, thresh.input, function=lambda x: x - 1)

    # Inhibit all ramp neurons while the reset signal is active
    nengo.Connection(ramp_reset.output, state.neurons,
                     transform=[[-5]] * state.n_neurons)


You generally have the right idea, and your network is very close to working! :smiley:
The three components you need (the integrator that produces the ramp, the threshold detector, and the reset signal generator) are all there, but the reset signal generator is missing one key feature.

If you look at the output of your reset signal, you’ll see that it quickly rises to about 1.2 and then stays there. Compare this with the plot in my first reply: my reset signal also rises quickly, but after some time it starts to decay back to 0. Introducing this decay is very simple, but first, let’s understand why your reset signal goes to ~1.2 and stays there.

In your code, these lines implement the reset signal:

ramp_reset = make_thresh_ens_net(0.07, thresh_func=lambda x: x)
nengo.Connection(ramp_reset.output, ramp_reset.input)

If we look at the feedback connection, we see that it uses the default transform value of 1. Thus, any output of the ramp_reset network is fed back into the input of ramp_reset, so the network maintains its value even after the thresh input goes back to 0. This is why the output of ramp_reset quickly goes to ~1.2 and then stays there.

To make the ramp_reset signal decay back to 0 after some time, all we have to do is ensure that the feedback connection does not feed the full output signal back to the input. This is done by applying a transform value less than 1 on the feedback connection. For your network, I found that a value of 0.955 works, but you can experiment to find a value that suits you.
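To see why a transform of 1 holds the value while 0.955 lets it decay, here is a minimal idealized model of the ramp_reset loop in plain Python. It assumes the threshold ensemble decodes its input perfectly above the 0.07 threshold and outputs 0 below it (matching `thresh_func=lambda x: x`), and that the feedback connection uses Nengo’s default lowpass synapse of 0.005 s, simulated here with forward Euler. `simulate_reset_memory` is a hypothetical helper, not part of Nengo.

```python
def simulate_reset_memory(gain, y0=1.2, threshold=0.07, tau=0.005,
                          dt=0.001, t_end=1.0):
    """Idealized ramp_reset feedback loop (assumptions in the text above).

    Returns the final value and the first time the value drops to or
    below the threshold (None if it never does).
    """
    y = y0
    t_below = None
    n_steps = int(round(t_end / dt))
    for step in range(1, n_steps + 1):
        # Idealized threshold ensemble: identity above threshold, 0 below
        decoded = y if y > threshold else 0.0
        # Lowpass synapse on the feedback connection, scaled by `gain`
        y += (dt / tau) * (gain * decoded - y)
        if t_below is None and y <= threshold:
            t_below = step * dt
    return y, t_below


held, t_held = simulate_reset_memory(gain=1.0)       # value is maintained
decayed, t_decay = simulate_reset_memory(gain=0.955)  # value leaks away
# Decay time constant while above threshold: tau / (1 - gain)
decay_time_constant = 0.005 / (1 - 0.955)
```

Under these assumptions, the value decays with a time constant of roughly tau / (1 - transform), about 0.11 s for a transform of 0.955, until it drops below the threshold and the loop shuts off. Lowering the transform shortens this, at the cost of a briefer reset pulse.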

ramp_reset = make_thresh_ens_net(0.07, thresh_func=lambda x: x)
nengo.Connection(ramp_reset.output, ramp_reset.input, transform=0.955)

Once you do this, you will see that the ramp_reset signal decays back down to 0 after some time:

And when we look at this in the context of the whole network, we see that the ramp signal resets and the cycle repeats! :smiley:

Great, I understand: you added decay to the second threshold network so that it stops influencing the ramp.
But how can I make it work without the pause? With a feedback of 0.955 (about 95%), the decay is pretty slow.
I think I could figure that out with some combination of thresholds.
Maybe the better question is: how do I reliably set the state to zero?
As for a faster decay, the beginning of the state curve after a reset does not follow the differential equation.

Hmm. Yeah. Unfortunately, with this method, there is no real way to shorten the gap between cycles. The network needs time to clear all of the “memory” stored in the integrator, and that is difficult to speed up. You might be able to speed it up by reducing the feedback time constant of the ramp integrator, but you run the risk of making the ramp integrator less stable:

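# grouth_equesions still returns x + 0.1 (the global tau), so with a 0.05 s
# synapse the ramp slope becomes roughly 0.1 / 0.05 = 2 units per second
nengo.Connection(state, state, function=grouth_equesions, synapse=0.05)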

I gave this problem some thought, and I have come up with another approach. Since you are looking to generate a cyclic signal, my idea is twofold:

  1. Create a self-sustaining neural oscillator (an example is here).
  2. Use the output of the neural oscillator, and compute the sawtooth function from that. If the oscillator provides the x, y coordinates of a point moving around a unit circle, we can use the arctan2 function to convert that to an angle, which can then be converted into the sawtooth signal:
sawtooth = 0.5 + np.arctan2(x1, x0) / (2 * np.pi)  # where x0 and x1 are points on a unit circle
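The arctan2 mapping can be checked on an ideal, noise-free point moving around the unit circle. This plain-Python sketch uses `math.atan2` in place of `np.arctan2` (same convention, angle in [-pi, pi]) and assumes an angular velocity of 3 rad/s, mirroring `da = 3.0` in the oscillator code of this thread; the helper names are hypothetical.

```python
import math


def sawtooth_from_circle(x0, x1):
    """Map a point on the unit circle to a value in [0, 1] via its angle."""
    return 0.5 + math.atan2(x1, x0) / (2 * math.pi)


def oscillator_point(t, omega=3.0):
    """Ideal point moving around the unit circle at omega rad/s (assumed)."""
    return math.cos(omega * t), math.sin(omega * t)


dt = 0.01
values = [sawtooth_from_circle(*oscillator_point(k * dt)) for k in range(400)]
# Count wrap-arounds: places where the value drops sharply back toward 0
n_wraps = sum(1 for a, b in zip(values, values[1:]) if b < a - 0.5)
```

Each full turn of the point produces one ramp up to 1, followed by a jump back down when the angle wraps from pi to -pi, so the sawtooth period is 2 * pi / omega (about 2.09 s here).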

Putting it together (note that the oscillator in the code below is a little more complex to try and maintain a radius as close to 1 as possible), we end up with this code:

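import nengo
import numpy as np


def recurrent_func(x, tau=0.1):
    # Convert the 2D state to polar coordinates
    x0, x1 = x
    r = np.sqrt(x0 ** 2 + x1 ** 2)
    a = np.arctan2(x1, x0)
    # Desired dynamics: radius driven toward 2, angle advancing at 3 rad/s
    dr = -(r - 2)
    da = 3.0
    # Principle 3 style update: value + tau * derivative
    # (tau should match the recurrent synapse below)
    r = r + tau * dr
    a = a + tau * da
    return [r * np.cos(a), r * np.sin(a)]


with nengo.Network() as model:
    # Self-sustaining 2D oscillator
    osc = nengo.Ensemble(1000, 2)
    nengo.Connection(
        osc,
        osc,
        function=recurrent_func,
        synapse=0.1,
    )

    # Decode the sawtooth from the oscillator's phase angle
    ramp = nengo.Ensemble(100, 1)
    nengo.Connection(
        osc, ramp, function=lambda x: 0.5 + np.arctan2(x[1], x[0]) / (2 * np.pi)
    )
    ramp_probe = nengo.Probe(ramp, synapse=0.01)

with nengo.Simulator(model) as sim:
    sim.run(5)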

When I simulate this, I get something like this:

The transition between cycles is not as sharp as with the previous (reset-based) network, because neurons are not good at representing sharp transitions. There is also some irregularity in the sawtooth output (it is not as clean as before), but I believe this can be remedied by changing the eval_points of the osc ensemble so that the decoders are optimized specifically for points on or near the unit circle. By default, the decoders are computed for a random selection of points within the unit circle. I’ll leave that up to you to experiment with, however.
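One way to try the eval_points idea is to sample evaluation points from a thin ring around the unit circle rather than from the whole disk. This NumPy sketch assumes a ring between radius 0.9 and 1.0 and 5000 points; both numbers are arbitrary starting values to experiment with, and `ring_eval_points` is a hypothetical helper.

```python
import numpy as np


def ring_eval_points(n_points=5000, r_min=0.9, r_max=1.0, seed=0):
    """Sample 2D evaluation points in an annulus near the unit circle."""
    rng = np.random.default_rng(seed)
    angles = rng.uniform(-np.pi, np.pi, size=n_points)
    # Sample r**2 uniformly so points are spread uniformly by area
    radii = np.sqrt(rng.uniform(r_min ** 2, r_max ** 2, size=n_points))
    return np.column_stack([radii * np.cos(angles), radii * np.sin(angles)])


points = ring_eval_points()
```

Nengo’s `Ensemble` accepts an (n_points, dimensions) array for `eval_points`, so the result could be passed as, for example, `nengo.Ensemble(1000, 2, eval_points=ring_eval_points())`. Whether this cleans up the sawtooth for your network would need to be checked in simulation.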