import matplotlib.pyplot as plt
import nengo
import numpy as np

dim = 10  # dimensionality of the input vectors
data = np.random.random(dim) * 2 - 1  # sample vector with entries in [-1, 1)
simtime = 1  # simulation length passed to sim.run

# Sanity-check, on a plain NumPy sample, the statistics the network approximates.
# Dividing by ``dim`` gives the population variance (ddof=0), which is also what
# np.var / np.std compute by default, so the paired values should agree.
data_mean = np.mean(data)
data_var = np.sum((data - data_mean) ** 2) / dim
print("Mean:", data_mean)
# This quantity is the variance (mean squared deviation), not a sum.
print("Var:", data_var, np.var(data))
print("STD:", np.sqrt(data_var), np.std(data))


def input_func(t, *, period=0.25, size=None):
    """Return a random vector in [-1, 1) that is redrawn once per ``period``.

    The current vector and the index of the period it belongs to are cached
    as attributes on the function (``input_func.data`` / ``input_func.index``),
    so repeated calls within the same period return the same array object.

    Args:
        t: Current simulation time in seconds.
        period: Length of time (seconds) each vector is held. Defaults to 0.25.
        size: Length of the generated vector. Defaults to the module-level ``dim``.

    Returns:
        The cached NumPy vector for the period containing ``t``.
    """
    if size is None:
        size = dim
    # Which period ``t`` falls in. The tiny epsilon absorbs floating-point
    # error: a time meant to be exactly on a boundary may arrive as
    # 0.24999999999999997, which an exact ``t % period == 0`` test would miss
    # (leaving the vector stale, or never created if t=0 is never seen).
    index = int(np.floor(t / period + 1e-9))
    if getattr(input_func, "index", None) != index:
        input_func.index = index
        input_func.data = np.random.random(size) * 2 - 1
    return input_func.data


with nengo.Network() as model:
    # Source of random input vectors (redrawn periodically by input_func).
    inp = nengo.Node(input_func)

    # Exact, non-neural standard deviation of the input, computed directly.
    std_node = nengo.Node(size_in=dim, output=lambda t, x: np.std(x))
    nengo.Connection(inp, std_node, synapse=None)
    probe_std = nengo.Probe(std_node)

    # One ensemble per input dimension; each decodes the square of its value.
    sqr_ens = nengo.networks.EnsembleArray(100, dim)
    sqr_ens.add_output("square", lambda x: x ** 2)

    # Together these two connections feed x - mean(x) into sqr_ens: the first
    # adds -mean(x) to every dimension, the second adds x itself.
    # NOTE(review): input_func draws values in [-1, 1), so individual
    # deviations from the mean can exceed 1 in magnitude, while sqr_ens uses
    # nengo's default radius (presumably 1) — confirm saturation is acceptable.
    nengo.Connection(inp, sqr_ens.input, transform=np.ones((dim, dim)) * -1.0 / dim)
    nengo.Connection(inp, sqr_ens.input)

    # Average of the squared deviations, i.e. the (population) variance.
    sum_ens = nengo.Ensemble(200, 1)
    nengo.Connection(sqr_ens.square, sum_ens, transform=np.ones((1, dim)) * 1.0 / dim)

    # Decode sqrt(variance) = standard deviation. The decoded function may be
    # evaluated where sum_ens represents negative values, so clamp at 0 before
    # the square root. np.maximum keeps the result an array shaped like x for
    # every input, rather than mixing a scalar 0 with an array.
    out = nengo.Node(size_in=1)
    nengo.Connection(sum_ens, out, function=lambda x: np.sqrt(np.maximum(x, 0)))

    probe_inp = nengo.Probe(inp)
    probe_sum = nengo.Probe(sum_ens, synapse=0.01)
    probe_out = nengo.Probe(out, synapse=0.01)

# Run the network for ``simtime``, then overlay the network's decoded standard
# deviation on the exact value computed from the probed input.
with nengo.Simulator(model) as sim:
    sim.run(simtime)

times = sim.trange()
reference_std = np.std(sim.data[probe_inp], axis=1)

plt.figure()
plt.plot(times, sim.data[probe_out], label="Out")
plt.plot(times, reference_std, "--", label="Ref")
plt.legend()
plt.show()
