import matplotlib.pyplot as plt
import nengo
import numpy as np

with nengo.Network() as model:
    # Ramp input: the node outputs the simulation time t, so over a
    # one-second run it sweeps from 0 up to 1.
    inp = nengo.Node(lambda t: t)
    # Passthrough node that receives the decoded sqrt estimate.
    out = nengo.Node(size_in=1)

    ens = nengo.Ensemble(
        200,
        1,
        # Intercepts are -0.15 plus an exponential sample of scale 0.02, so
        # almost all of them fall in roughly [-0.15, -0.05]. Every neuron is
        # therefore already active by x ~ 0, which presumably helps
        # approximate the steep part of sqrt near 0 (the evenly spread
        # alternative tried previously was nengo.dists.Uniform(-0.15, 1)).
        intercepts=nengo.dists.Exponential(0.02, -0.15),
        # All encoders point in the +1 direction: every neuron responds to
        # increasing x, matching the non-negative values of the input ramp.
        encoders=nengo.dists.Choice([[1]]),
    )
    nengo.Connection(inp, ens)
    # Decode sqrt(x), treating negative x as 0. np.maximum keeps the result
    # an array of the connection's size for every evaluation point. It
    # avoids returning a bare int on one branch and an ndarray on the
    # other, and it does not depend on the truthiness of an array
    # comparison.
    nengo.Connection(ens, out, function=lambda x: np.sqrt(np.maximum(x, 0)))

    # Unfiltered input (used for the reference curve) and the decoded output
    # low-pass filtered with a 5 ms synapse.
    p_inp = nengo.Probe(inp)
    p_out = nengo.Probe(out, synapse=0.005)

# Build the network and simulate one second of model time. The simulator is
# closed when the block exits; its recorded probe data is read afterwards
# for plotting.
with nengo.Simulator(model) as sim:
    sim.run(time_in_seconds=1)

# Compare the decoded (filtered) output against the ideal sqrt of the input
# ramp. Labels and a legend make the two traces distinguishable.
times = sim.trange()
plt.figure()
plt.plot(times, sim.data[p_out], label="decoded sqrt(x)")
plt.plot(times, np.sqrt(sim.data[p_inp]), "--", label="ideal sqrt(x)")
plt.xlabel("Time (s)")
plt.ylabel("Value")
plt.legend(loc="best")
plt.show()
