import pandas as pd
import numpy as np
import pyldl.mapping as pmap
import nengo
#%matplotlib inline
import matplotlib.pyplot as plt
#from scipy import spatial
from nengo.dists import Gaussian, Uniform,DistOrArrayParam
from nengo.params import NdarrayParam

#from wikipedia2vec import Wikipedia2Vec


# Create dummy data: random input rows (c_mat), matching target rows (s_mat)
# and one presentation duration per row (duration). All three arrays have the
# same number of rows; every value is drawn uniformly from [low, high).

c_mat = np.random.uniform(0.0, 1.0, (4859, 182))
s_mat = np.random.uniform(0.0, 1.0, (4859, 100))
duration = np.random.uniform(0.0, 0.5, (4859, 1))

# Custom nengo Process: present each input row for its own duration, looping


class FlexibleDuration(nengo.Process):
    """Present rows of ``inputs`` one at a time, each for its own duration.

    Row ``i`` is output until the internal clock passes the current deadline.
    The process then advances to the next row and extends the deadline by that
    row's entry in ``presentation_time``. After the last row it wraps back to
    row 0, so the sequence loops for as long as the simulation runs.

    Parameters
    ----------
    inputs : array_like
        Stimuli, one per row; the output size is ``inputs[0].size``.
    presentation_time : array_like
        One duration per row of ``inputs`` (shape ``(n,)`` or ``(n, 1)``), in
        the same time units as the simulator's ``dt``.
    timer : float or size-1 array, optional
        Initial deadline at which the first presented row ends. Defaults to
        the duration of row ``index``.
    index : int, optional
        Row presented first (default 0).
    internal_t : float, optional
        Initial value of the internal clock (default 0).
    verbose : bool, optional
        If True, print the schedule state on every step and on every advance.
        Off by default: printing on every timestep floods stdout and slows the
        simulation down.
    **kwargs
        Forwarded to ``nengo.Process``.

    Notes
    -----
    ``timer``, ``index`` and ``internal_t`` are *initial* values only and are
    never modified. The running state lives inside the step function built by
    ``make_step``, so every new ``nengo.Simulator`` (and ``sim.reset()``, which
    rebuilds step functions) restarts the schedule. Because the state is
    copied into plain floats, it can never write back into a caller's array.
    """

    inputs = NdarrayParam("inputs", shape=("...",))
    presentation_time = NdarrayParam("presentation_time", shape=("...",))

    def __init__(
        self,
        inputs,
        presentation_time,
        timer=None,
        index=0,
        internal_t=0.0,
        *,
        verbose=False,
        **kwargs,
    ):
        self.inputs = inputs
        self.presentation_time = presentation_time
        # Initial schedule values; make_step copies them and never writes back.
        self.timer = timer
        self.index = index
        self.internal_t = internal_t
        self.verbose = verbose
        super().__init__(
            default_size_in=0, default_size_out=self.inputs[0].size, **kwargs
        )

    def make_step(self, shape_in, shape_out, dt, rng, state):
        assert shape_in == (0,)
        assert shape_out == (self.inputs[0].size,)

        n = len(self.inputs)
        inputs = self.inputs.reshape(n, -1)

        # Exactly one duration per input row is required; (n,) and (n, 1)
        # layouts are both accepted.
        durations = np.asarray(self.presentation_time, dtype=float)
        durations = durations.reshape(len(durations), -1)
        if durations.shape[1] != 1 or len(durations) < n:
            raise ValueError(
                "presentation_time must provide one duration per input row; "
                f"got shape {np.shape(self.presentation_time)} for {n} inputs"
            )
        durations = durations[:, 0].tolist()

        # Per-simulator mutable state, held in this closure as plain Python
        # scalars instead of on `self`. It is rebuilt on every make_step call
        # and cannot alias (and mutate) an array passed in by the caller.
        index = int(self.index) % n
        if self.timer is None:
            timer = durations[index]
        else:
            timer = float(np.asarray(self.timer, dtype=float).item())
        internal_t = float(self.internal_t)
        verbose = self.verbose

        def step_presentinput(t):
            nonlocal index, timer, internal_t
            internal_t += dt
            if internal_t > timer:
                index += 1
                if index == n:
                    # Past the last row: loop back to the first one.
                    index = 0
                    if verbose:
                        print("Wrapping around to the first input")
                timer += durations[index]
                if verbose:
                    print("Increment values ", internal_t, index, timer)
            if verbose:
                print(
                    "Process, Index: ", index,
                    "Time: ", internal_t,
                    "Increment: ", timer,
                )
            return inputs[index]

        return step_presentinput
    
# Define processes that loop over the input and target rows.
#
# Both start at row 0, with the first row's duration as the initial deadline.
# `.item()` extracts that duration as an independent Python float. The plain
# `duration[0]` would be a *view* into `duration`, so any in-place update of a
# process's timer would rewrite `duration` itself. Both processes would then
# share a single deadline and stop advancing in lock-step.

c_timer = duration[0].item()
c_index = 0
c_internal_t = 0

s_timer = duration[0].item()
s_index = 0
s_internal_t = 0

process_in = FlexibleDuration(
    c_mat,
    presentation_time=duration,
    timer=c_timer,
    index=c_index,
    internal_t=c_internal_t,
)
process_target = FlexibleDuration(
    s_mat,
    presentation_time=duration,
    timer=s_timer,
    index=s_index,
    internal_t=s_internal_t,
)

# Init model: the ensemble's dimensionality follows the data. We draw 10000
# intercepts in [0, 0.9), plus 10000 encoders and 10000 evaluation points whose
# coordinates are uniform in [0, 1). The encoders are not normalized here.

n_in, n_out = c_mat.shape[1], s_mat.shape[1]

intercept_vals = np.random.uniform(0.0, 0.9, 10000)
encoder_vals = np.random.uniform(0.0, 1.0, (intercept_vals.size, n_in))
eval_vals = np.random.uniform(0.0, 1.0, (intercept_vals.size, n_in))


# Create and run model.
#
# Input rows drive a LIF ensemble. A PES-learned connection from the
# ensemble's neurons to `out` is trained on the error signal (out - target).

model = nengo.Network()

with model:

    inp = nengo.Node(process_in)
    target = nengo.Node(process_target)

    # NOTE(review): if the input rows are uniform in [0, 1) over n_in
    # dimensions, their norm is far larger than radius=1. Check whether the
    # ensemble saturates.
    inp_ens = nengo.Ensemble(
        n_neurons=len(intercept_vals),
        dimensions=n_in,
        radius=1,
        intercepts=intercept_vals,
        encoders=encoder_vals,
        eval_points=eval_vals,
        neuron_type=nengo.LIF(),
    )
    nengo.Connection(inp, inp_ens)

    out = nengo.Node(None, size_in=n_out)
    # Learning starts from all-zero weights. The transform shape is taken from
    # the actual output size and neuron count, not hard-coded.
    learn_con = nengo.Connection(
        inp_ens.neurons,
        out,
        transform=np.zeros((n_out, inp_ens.n_neurons)),
        synapse=0.0005,
        learning_rule_type=nengo.PES(learning_rate=0.001),
    )

    error = nengo.Node(None, size_in=n_out)

    nengo.Connection(out, error)
    nengo.Connection(target, error, transform=-1)
    nengo.Connection(error, learn_con.learning_rule)

    # For output generation and visualization
    # inp_probe = nengo.Probe(inp)
    target_probe = nengo.Probe(target)
    # ens_spikes = nengo.Probe(inp_ens.neurons)
    # ens_probe = nengo.Probe(inp_ens.neurons, synapse=0.01)
    out_probe = nengo.Probe(out)
    error_probe = nengo.Probe(error, synapse=0.03)


with nengo.Simulator(model) as sim:
    sim.run(1)

    
# Plot the first (up to) 10 dimensions of the target, one subplot each.
# The figure size is set on this figure only, leaving the global matplotlib
# rcParams untouched for any later figures.

target_data = sim.data[target_probe]
n_plot = min(10, target_data.shape[1])

fig, axes = plt.subplots(n_plot, 1, figsize=(20, 50), sharex=True, squeeze=False)

for i, ax in enumerate(axes[:, 0]):
    ax.plot(sim.trange(), target_data[:, i], c="k", label="Target")
    # ax.plot(sim.trange(), sim.data[out_probe][:, i], c="r", label="Output")
    ax.set_ylabel('Dimension %d' % i)
    ax.legend(loc="best")

axes[-1, 0].set_xlabel("Time (s)")
# Needed to display the figure when run as a plain script; harmless inline.
plt.show()
