import numpy as np
import matplotlib.pyplot as plt

# import nengo_dl
import nengo
from nengo import solvers
from nengo.utils.ensemble import sorted_neurons
from nengo.dists import Uniform
from nengo.utils.matplotlib import rasterplot

# Read the Iris dataset from CSV: columns 1-4 are the four numeric
# measurements. Row 0 is the CSV header line (not numeric), which is why it
# is dropped below -- assumes the file has exactly one header line.
data = np.genfromtxt('Iris.csv', delimiter=',', usecols=(1, 2, 3, 4))


def _unit_norm(pattern):
    """Return *pattern* divided by its Euclidean (L2) norm."""
    return pattern / np.linalg.norm(pattern)


# Scale every pattern (row) to unit length, skipping the header row.
features = np.apply_along_axis(_unit_norm, 1, data[1:])

# Class labels (column 5) as strings with the header cell removed, plus the
# sorted set of distinct labels.
target = np.genfromtxt('Iris.csv', delimiter=',', usecols=[5], dtype=str)[1:]
labels = np.unique(target)


class DataFeeder:
    """
    Replay the rows of a feature matrix as piecewise-constant Nengo node inputs.

    Each accessor method is meant to be used as a ``nengo.Node`` output
    function: Nengo calls it with the current simulation time ``t`` (seconds)
    and it returns one column of the row selected by

        idx = int(window * t) % n_rows

    so each row is held for ``1 / window`` seconds before the next one is
    shown, and the data wraps around after ``n_rows / window`` seconds.
    With the default ``window=20`` a row is held for 1 / 20 = 0.05 s = 50 ms,
    i.e. for 50 consecutive steps at dt = 0.001 s. Holding the row here lets
    the simulator keep its small dt while each pattern stays on for a
    biologically more plausible duration.

    Worked example (window = 20, 150 rows):

        t = 0.000 -> int(20 * 0.000) = 0 -> idx = 0 % 150 = 0
        t = 0.001 -> int(20 * 0.001) = 0 -> idx = 0 % 150 = 0
        ...
        t = 0.050 -> int(20 * 0.050) = 1 -> idx = 1 % 150 = 1

    Parameters
    ----------
    window : int or float
        Number of rows presented per simulated second (default 20).
    data : array_like, optional
        2-D matrix whose rows are patterns; the accessors read columns 0-3.
        When omitted, the module-level ``features`` array is used (looked up
        at call time), which preserves the original behaviour.

    Attributes
    ----------
    t : float
        Time passed to the most recent accessor call (0 before any call).
    idx : int
        Row index selected by the most recent accessor call.
    """

    def __init__(self, window=20, data=None):
        self.t = 0
        self.window = window
        self.idx = 0
        # None means "fall back to the module-level ``features`` array".
        self.data = None if data is None else np.asarray(data)

    def _value(self, t, column):
        """Record ``t``, update ``self.idx`` and return ``column`` of that row."""
        rows = features if self.data is None else self.data
        self.t = t
        self.idx = int(self.window * self.t) % rows.shape[0]
        return rows[self.idx, column]

    def sepal_length(self, t):
        """Return column 0 (sepal length) of the row shown at time ``t``."""
        return self._value(t, 0)

    def sepal_width(self, t):
        """Return column 1 (sepal width) of the row shown at time ``t``."""
        return self._value(t, 1)

    def petal_length(self, t):
        """Return column 2 (petal length) of the row shown at time ``t``."""
        return self._value(t, 2)

    def petal_width(self, t):
        """Return column 3 (petal width) of the row shown at time ``t``."""
        return self._value(t, 3)


# Build the network: four scalar input nodes fed from the Iris features,
# a 10-neuron input ensemble, and a 10-neuron output ensemble connected with
# a full neuron-to-neuron weight matrix.
# NOTE(review): object creation order matters for reproducibility, because the
# per-object seeds are derived from this seeded network -- do not reorder.
model = nengo.Network(seed=0)  # Seeded so that repeated builds/runs of this network are identical

with model:

    # Nodes: one scalar input per Iris feature, all driven by the same feeder.
    # The string literal below is disabled code (constant-input nodes).

    """feature_1 = nengo.Node(0)
    feature_2 = nengo.Node(0)
    feature_3 = nengo.Node(0)
    feature_4 = nengo.Node(0)"""

    # Each node outputs one column of the current row; the feeder advances
    # to the next row every 1/window seconds of simulated time.
    data_feeder = DataFeeder()
    feature_1 = nengo.Node(output=data_feeder.sepal_length, label="Features 1")
    feature_2 = nengo.Node(output=data_feeder.sepal_width, label="Features 2")
    feature_3 = nengo.Node(output=data_feeder.petal_length, label="Features 3")
    feature_4 = nengo.Node(output=data_feeder.petal_width, label="Features 4")

    # Ensembles: two 1-D populations of 10 neurons each.

    ens1 = nengo.Ensemble(n_neurons=10, dimensions=1, label="Ensemble 1")
    ens2 = nengo.Ensemble(n_neurons=10, dimensions=1, label="Ensemble 2")

    # Connections: all four features project onto the same 1-D ensemble, so
    # ens1 receives the sum of the four feature values.
    # NOTE(review): each row is unit-normalised, so that sum can reach 2, while
    # ens1 uses the default radius -- confirm saturation is intended.

    nengo.Connection(feature_1, ens1)  # , solver=nengo.solvers.LstsqL2(weights=True))
    nengo.Connection(feature_2, ens1)  # , solver=nengo.solvers.LstsqL2(weights=True))
    nengo.Connection(feature_3, ens1)  # , solver=nengo.solvers.LstsqL2(weights=True))
    a = nengo.Connection(feature_4, ens1)  # , solver=nengo.solvers.LstsqL2(weights=True))

    # Debug output: the (default) parameters of one of the input connections.
    print("learning_rule_type", a.learning_rule_type)
    print("function", a.function)
    print("synapse", a.synapse)
    print("transform", a.transform)
    print("solver", a.solver)
    print("label", a.label)

    # ens1 -> ens2 with weights=True, so the solver produces a full
    # neuron-to-neuron weight matrix; neuron_connection() later edits the
    # built "weights" signal of this connection.
    conn = nengo.Connection(pre=ens1, post=ens2,
                            solver=nengo.solvers.LstsqL2(weights=True))  # , transform=nengo_dl.dists.Glorot())

    # NOTE(review): this prints the ensemble's encoders *parameter* as
    # configured at model-build time, not the built encoder matrix
    # (that lives in sim.data[ens1].encoders after a Simulator is built).
    print("encoders in ens1", ens1.encoders)

    # Probes: raw node outputs, the decoded ens2 value filtered with a 5 ms
    # synapse, and the ens2 spike trains.

    feature_1_probe = nengo.Probe(feature_1)
    feature_2_probe = nengo.Probe(feature_2)
    feature_3_probe = nengo.Probe(feature_3)
    feature_4_probe = nengo.Probe(feature_4)

    ens2_probe = nengo.Probe(ens2, synapse=0.005)
    ens2_spikes = nengo.Probe(ens2.neurons)


def neuron_connection(simulation, connection, neuron_idx, *,
                      before_path="weights_before.txt",
                      after_path="weights_after.txt"):
    """Zero selected rows of a connection's built weight matrix, in place.

    Looks up the ``"weights"`` signal the builder created for ``connection``
    inside ``simulation`` and overwrites the rows selected by ``neuron_idx``
    with zeros. The built signal is normally read-only, so it is unlocked only
    for the duration of the write. Its original writeable state is restored
    afterwards, even if the write fails. Only steps simulated after this call
    see the modified weights.

    NOTE(review): for a connection built with ``LstsqL2(weights=True)`` the
    matrix is presumably (post.n_neurons, pre.n_neurons), so each zeroed row
    removes all input to one post-synaptic neuron -- confirm against the
    Nengo builder.

    Parameters
    ----------
    simulation : nengo.Simulator
        A built simulator that contains ``connection``.
    connection : nengo.Connection
        The connection whose weights are edited.
    neuron_idx : int, slice or sequence of int
        Row index/indices to set to zero (standard NumPy indexing).
    before_path, after_path : str or path-like or None, keyword-only
        Files the matrix is written to (``np.savetxt``) before and after the
        edit. Defaults keep the original file names; pass ``None`` to skip.

    Raises
    ------
    KeyError
        If the simulation has no built ``"weights"`` signal for ``connection``.
    IndexError
        If ``neuron_idx`` is out of range (the weights are left unchanged).
    """
    weights = simulation.signals[simulation.model.sig[connection]["weights"]]

    if before_path is not None:
        np.savetxt(before_path, weights)

    # Unlock temporarily and restore the *previous* flag instead of forcing
    # read-only: a signal the simulator itself updates in place must stay
    # writeable, and a failed write must not leave the array unlocked.
    was_writeable = weights.flags.writeable
    weights.setflags(write=True)
    try:
        weights[neuron_idx, :] = 0
    finally:
        weights.setflags(write=was_writeable)

    if after_path is not None:
        np.savetxt(after_path, weights)
    # print(f"\nWeights of neuron index that are set to zero:\n{neuron_idx} \n", conn_weights_sig[neuron_idx, :])


# Run the model twice -- once unmodified and once with the ens1 -> ens2
# weights silenced -- and plot the ens2 spike rasters one above the other.
# 7500 steps of 1 ms = 7.5 s; with the feeder's default window=20 that is
# 150 patterns, i.e. one pass over the dataset (assumes 150 rows, as in Iris).
dt = 1e-3
with nengo.Simulator(model, dt=dt) as sim:
    sim.run(dt * 7500)
    # spike_count = np.sum(sim.data[ens2_spikes] > 0, axis=0)

# print(f'spike counts are {spike_count}')

# NOTE(review): `indices` (neurons ordered by sorted_neurons) is only used by
# the commented-out `[:, indices]` column selection below, so this call is
# currently unused work. It runs after the `with` block has closed `sim` and
# relies on sim.data still being readable -- confirm.
indices = sorted_neurons(ens2, sim, iterations=250)

# Top panel: baseline ens2 spikes, first 0.5 s only.
plt.subplot(2, 1, 1)
rasterplot(sim.trange(), sim.data[ens2_spikes])  # [:, indices])
plt.xlim(0, 0.5)
plt.title("No Neuron extraction")

# Second run on a freshly built simulator: rows 0-9 of the 10x10 ens1 -> ens2
# weight matrix (i.e. the whole matrix) are zeroed before running, so ens2
# gets no input through `conn`. This also writes weights_before.txt and
# weights_after.txt to the current directory. Note that `sim` is rebound here.
with nengo.Simulator(model, dt=dt) as sim:
    neuron_connection(simulation=sim, connection=conn, neuron_idx=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    sim.run(dt * 7500)
    # spike_count = np.sum(sim.data[ens2_spikes] > 0, axis=0)

# print(f'spike counts are {spike_count}')


# Disabled code (string literal): inspecting built parameters of `conn`/`ens2`.
"""print("weights", sim.data[conn].weights)  # connection weights
print("bias", sim.data[ens2].bias)  # bias values
print("encoders", sim.data[ens2].encoders)  # encoder values
print("data", sim.data[ens2])  # to see all the parameters for an object"""


# Bottom panel: ens2 spikes from the silenced run, same 0.5 s window.
plt.subplot(2, 1, 2)
rasterplot(sim.trange(), sim.data[ens2_spikes])  # [:, indices])
plt.xlim(0, 0.5)
plt.title("Neuron extraction")
plt.tight_layout()
plt.show()

# Disabled code (string literal): plot the four raw feature inputs over time.
"""
plt.plot(sim.trange(), sim.data[feature_1_probe], label="Sepal Length")
plt.plot(sim.trange(), sim.data[feature_2_probe], label="Sepal Width")
plt.plot(sim.trange(), sim.data[feature_3_probe], label="Petal Length")
plt.plot(sim.trange(), sim.data[feature_4_probe], label="Petal Width")

plt.xlabel("Simulation Time-step")
plt.ylabel("Values")

plt.legend()
plt.show()"""

# Disabled code (string literal): launch the model in nengo_gui instead.
'''
if __name__ == "__main__":
    import nengo_gui
    nengo_gui.GUI(__file__).start()
'''
