import matplotlib.pyplot as plt

import nengo
import nengo_spa as spa


# Dimensionality shared by every component of this script: it is passed as the
# first argument to each spa.State, used as the Transcode output vocabulary,
# and used as the key into model.vocabs when computing similarities.
ndim = 64
# Time constant (presumably in seconds -- TODO confirm) used both as the
# feedback synapse of every state and as the transform and the synapse of the
# connections that feed the input into the states.
tau = 0.1


def input_func(t):
    """Return the input expression to present at time ``t``.

    The string ``"A"`` is returned while ``t`` is below 1; from ``t == 1``
    onwards the string ``"0"`` is returned instead.
    """
    return "A" if t < 1 else "0"

# Subdimension settings under comparison. Each setting gets its own recurrent
# state, input connection, probe and curve in the final figure.
subdims_to_compare = (1, 2, 16, 64)

with spa.Network() as model:
    # Input signal: whatever input_func returns for the current time.
    spa_input = spa.Transcode(input_func, output_vocab=ndim)

    # Objects are created in the same order as before (all states, then all
    # input connections, then all probes); only the repetition is removed.
    states = {}
    for subdims in subdims_to_compare:
        # One state per setting; they differ only in `subdimensions`.
        states[subdims] = spa.State(
            ndim,
            subdimensions=subdims,
            feedback=1,
            feedback_synapse=tau,
            represent_cc_identity=False,
        )

    for subdims in subdims_to_compare:
        # The input is scaled by tau and filtered with the same time constant
        # that is used for the state's feedback connection.
        nengo.Connection(
            spa_input.output, states[subdims].input, transform=tau, synapse=tau
        )

    probes = {}
    for subdims in subdims_to_compare:
        probes[subdims] = nengo.Probe(states[subdims].output, synapse=0.05)


with nengo.Simulator(model) as sim:
    sim.run(10)

# Shared by every curve, so look them up once instead of once per plot call.
times = sim.trange()
vocab = model.vocabs[ndim]

plt.figure()
for subdims in subdims_to_compare:
    # NOTE(review): similarity is computed against the whole vocabulary; the
    # title assumes 'A' is its only key -- confirm if more keys get added.
    plt.plot(
        times,
        spa.similarity(sim.data[probes[subdims]], vocab),
        label="subdim {}".format(subdims),
    )
plt.xlabel("Time [s]")
plt.ylabel("Similarity")
plt.legend()
plt.title("Similarity of stored vector w.r.t. 'A'")
plt.show()
