import matplotlib.pyplot as plt
import nengo
import numpy as np

dim = 4  # dimensionality of the vectors being permuted
num_values = 100  # number of distinct input vectors to cycle through

# Input vectors: every component is drawn uniformly from [-0.5, 0.5).
values = np.random.random((num_values, dim)) - 0.5

# Project any vector whose Euclidean norm exceeds 1 back onto the unit
# sphere. Iterating over a 2-D array yields writable row views, so the
# in-place division updates ``values`` directly.
for vec in values:
    length = np.linalg.norm(vec)
    if length > 1:
        vec /= length

dt = 0.001  # simulator time step (s)

runtime = 210  # total simulated time (s)
stop_learn = 10  # final stretch (s) during which error_func zeroes the error

# Target mapping: the identity matrix with its rows shuffled, i.e. a random
# permutation matrix.
ref_matrix = np.identity(dim)
np.random.shuffle(ref_matrix)


def cycle_array(x, period=0.2, dt=dt):
    """Return a node function that cycles through the elements of ``x``.

    The returned ``f(t)`` holds each element of ``x`` for ``period`` seconds,
    then advances to the next one, wrapping back to the first element after
    the last.

    Parameters
    ----------
    x : sequence
        Non-empty collection of items, indexed by position (e.g. the rows of
        an array).
    period : float
        Seconds each item is held; must be a whole, positive number of
        time steps.
    dt : float
        Simulation time step in seconds.

    Returns
    -------
    callable
        ``f(t)`` returning the item of ``x`` active at time ``t``.

    Raises
    ------
    ValueError
        If ``x`` is empty, ``period`` is shorter than one time step, or
        ``period`` is not an integer multiple of ``dt``.
    """
    n_items = len(x)
    if n_items == 0:
        raise ValueError("x must contain at least one element")

    i_every = int(round(period / dt))
    if i_every < 1:
        raise ValueError(f"period ({period}) must be at least one time step (dt={dt})")
    # Compare with a tiny relative tolerance: period / dt is a float quotient,
    # so an exact != test is sensitive to binary representation error.
    if not np.isclose(period / dt, i_every, rtol=1e-9, atol=0.0):
        raise ValueError(f"dt ({dt}) does not divide period ({period})")

    def f(t):
        step = int(round((t - dt) / dt))  # t starts at dt, so step 0 is t == dt
        # Clamp so a call with t < dt (step -1) still returns the first item,
        # as the previous int() truncation toward zero did. Wrap with x's own
        # length rather than a global count so any x size works.
        return x[(max(step, 0) // i_every) % n_items]

    return f


def error_func(t, x, shutoff=runtime - stop_learn):
    """Pass the error signal through until ``shutoff``, then output zeros.

    Parameters
    ----------
    t : float
        Current simulation time in seconds.
    x : array_like
        Error vector arriving at this node.
    shutoff : float
        Time after which the output is forced to zero. The default covers
        the last ``stop_learn`` seconds of a ``runtime``-second run.

    Returns
    -------
    array_like
        ``x`` unchanged while ``t <= shutoff``; otherwise a zero array with
        the same shape and dtype as ``x``.
    """
    if t > shutoff:
        # Build the zeros from x rather than the global ``dim`` so the node
        # works for any input size.
        return np.zeros_like(x)
    return x


# Build the network. A learned connection from ``pre`` to ``post`` is trained
# with PES so that its output approaches ``ref_matrix @ input``.
with nengo.Network() as model:
    # Input node: presents each row of ``values`` in turn (see cycle_array).
    inp = nengo.Node(cycle_array(values))

    # Neural ensemble representing the dim-dimensional input; the neuron
    # count is int(50 * dim ** 1.5).
    pre = nengo.Ensemble(int(50 * dim ** 1.5), dim)
    nengo.Connection(inp, pre)

    # Node that receives the learned connection's output.
    post = nengo.Node(size_in=dim)
    # Learned connection. The initial function ignores its input and returns
    # random values, so learning starts from an arbitrary mapping rather than
    # the target; PES (learning_rate=1e-4) then adapts it.
    conn = nengo.Connection(
        pre,
        post,
        function=lambda x: np.random.random(dim),
        learning_rule_type=nengo.PES(learning_rate=1e-4),
    )

    # Error node: its two incoming connections combine to
    # post - ref_matrix @ inp (actual minus target). error_func gates it to
    # zero after the shutoff time before it reaches the learning rule.
    error = nengo.Node(error_func, size_in=dim)
    nengo.Connection(post, error)
    nengo.Connection(inp, error, transform=ref_matrix * -1)
    nengo.Connection(error, conn.learning_rule)

    # Target signal (permuted input); only probed and plotted for comparison.
    ref = nengo.Node(size_in=dim)
    nengo.Connection(inp, ref, transform=ref_matrix)

    # Probes for plotting. Only the learned output is smoothed (synapse=0.01).
    p_in = nengo.Probe(inp)
    p_post = nengo.Probe(post, synapse=0.01)
    p_err = nengo.Probe(error)
    p_ref = nengo.Probe(ref)

# Simulate the full run; error_func zeroes the error signal for the final
# ``stop_learn`` seconds. The context manager closes the simulator afterwards.
with nengo.Simulator(model, dt=dt) as sim:
    sim.run(time_in_seconds=runtime)

# Overview of the whole run: input, learned output, error and target, each
# in its own stacked subplot (4 rows, 1 column).
plt.figure(figsize=(16, 10))
for row, probe in enumerate((p_in, p_post, p_err, p_ref), start=1):
    plt.subplot(4, 1, row)
    plt.plot(sim.trange(), sim.data[probe])
plt.tight_layout()

# Detailed view of the last ``plot_last`` seconds of the run, when the error
# signal has been switched off.
plot_last = 10
# Round rather than truncate: plot_last / dt is a float quotient and int()
# can land one sample short (same approach as cycle_array).
plot_ind = int(round(plot_last / dt))
t_last = sim.trange()[-plot_ind:]

plt.figure(figsize=(14, 8))
colors = ["r", "b", "g", "y", "k", "c", "m"]

plt.subplot(311)
plt.title("Input")
for i in range(dim):
    # Wrap around the palette so dim > len(colors) does not raise IndexError.
    plt.plot(t_last, sim.data[p_in][-plot_ind:, i], colors[i % len(colors)])
plt.subplot(312)
plt.title("Learned output vs Expected output")
for i in range(dim):
    color = colors[i % len(colors)]
    plt.plot(t_last, sim.data[p_post][-plot_ind:, i], color)
    plt.plot(t_last, sim.data[p_ref][-plot_ind:, i], color + "--")
# The two legend labels attach to the first two lines drawn (dimension 0's
# solid and dashed traces); they stand for the solid/dashed line styles.
plt.legend(["Output", "Expected"])
plt.subplot(313)
plt.plot(
    t_last,
    sim.data[p_post][-plot_ind:, :] - sim.data[p_ref][-plot_ind:, :],
)
plt.title("Difference between learned output and expected value")
plt.tight_layout()
plt.show()
