"""
Training MNIST with PES.

We train a network to classify MNIST digits using the PES learning rule.
We will need nengo_spa, because high-dimensional input and output are hard
to represent with core nengo alone.
"""

import nengo
import nengo_spa as spa
import numpy as np
import tensorflow as tf

from nengo_extras.data import one_hot_from_labels
from nengo.processes import PresentInput

from pdb import set_trace as bp

# Load the MNIST dataset as (images, labels) pairs for the train and test splits.
(
    (X_train, y_train),
    (X_test, y_test),
) = tf.keras.datasets.mnist.load_data()

# Preprocessing
def preprocess(X: np.ndarray) -> np.ndarray:
    """Flatten each sample to a vector and min-max scale its values into [-1, 1].

    Args:
        X: Array of samples. The first axis indexes samples; all remaining
            axes are flattened into a single feature axis.

    Returns:
        Float array of shape ``(len(X), n_features)``. The smallest value in
        ``X`` maps to -1 and the largest to 1. A constant input has no range
        to scale, so every element maps to -1 (instead of nan/inf).

    Raises:
        TypeError: If ``X`` is not a NumPy array.
    """
    if not isinstance(X, np.ndarray):
        # An explicit raise (unlike ``assert``) still runs under ``python -O``.
        raise TypeError("X is not np.ndarray!")
    # Compute the feature count explicitly so that an empty X (0 samples) can
    # still be reshaped; ``reshape(0, -1)`` is ambiguous and raises.
    n_features = int(np.prod(X.shape[1:]))
    flat = X.reshape(len(X), n_features).astype(np.float64)
    if flat.size == 0:
        return flat
    lo = flat.min()
    span = flat.max() - lo
    if span == 0:
        span = 1.0  # constant input: avoid division by zero
    # Shift by the minimum first so the output spans exactly [-1, 1] even
    # when the data does not start at 0.
    return (flat - lo) / span * 2 - 1
    
# Turn each image split into flat [-1, 1] vectors and each label split into one-hot rows.
X_train, X_test = preprocess(X_train), preprocess(X_test)
T_train, T_test = one_hot_from_labels(y_train), one_hot_from_labels(y_test)

# Step through the samples, holding each row for 2 s of simulated time.
# NOTE(review): the model construction below feeds data through
# load_data/load_label nodes rather than these processes -- confirm whether
# they are still needed.
input_process = PresentInput(X_train, presentation_time=2)
label_process = PresentInput(T_train, presentation_time=2)

# Network dimensions are taken directly from the prepared data arrays.
input_size = X_train.shape[1]  # 784 values per flattened image
output_size = T_train.shape[1]  # 10 one-hot label entries
n_neurons = 1000  # default ensemble size

model = nengo.Network()

def load_data(t, period=2, data=None):
    """Return the training image to present at simulation time ``t``.

    Each row of ``data`` is held for ``period`` seconds. Once every row has
    been shown, presentation wraps around to the first row instead of
    raising ``IndexError``.

    Args:
        t: Current simulation time in seconds (the value a ``nengo.Node``
            passes to its output function).
        period: How long each sample is presented, in seconds.
        data: 2-D array with one sample per row. Defaults to the module-level
            ``X_train``.

    Returns:
        The row of ``data`` for the current presentation window.

    Raises:
        ValueError: If ``data`` has no rows.
    """
    samples = X_train if data is None else data
    if len(samples) == 0:
        raise ValueError("cannot present samples from an empty array")
    # Wrap so that simulations longer than len(samples) * period keep cycling
    # through the dataset rather than indexing past the end.
    index = int(t // period) % len(samples)
    return samples[index, :]
    
def load_label(t, period=2, labels=None):
    """Return the one-hot training label to present at simulation time ``t``.

    Each row of ``labels`` is held for ``period`` seconds. Once every row has
    been shown, presentation wraps around to the first row instead of
    raising ``IndexError``.

    Args:
        t: Current simulation time in seconds (the value a ``nengo.Node``
            passes to its output function).
        period: How long each label is presented, in seconds.
        labels: 2-D array with one label vector per row. Defaults to the
            module-level ``T_train``.

    Returns:
        The row of ``labels`` for the current presentation window.

    Raises:
        ValueError: If ``labels`` has no rows.
    """
    targets = T_train if labels is None else labels
    if len(targets) == 0:
        raise ValueError("cannot present labels from an empty array")
    # Wrap so that simulations longer than len(targets) * period keep cycling
    # through the labels rather than indexing past the end.
    index = int(t // period) % len(targets)
    return targets[index, :]

with model:
    # Stimulus nodes: at simulation time t they output the current training
    # image and its one-hot label.
    # NOTE(review): ``input`` shadows the Python builtin; the name is kept so
    # any later code referring to this node keeps working.
    input = nengo.Node(lambda t: load_data(t))
    label = nengo.Node(lambda t: load_label(t))

    # Population representing the flattened input image.
    pre = nengo.Ensemble(
        n_neurons=n_neurons,
        dimensions=input_size,
        intercepts=nengo.dists.Choice([0.1]),
        max_rates=nengo.dists.Choice([100]),
    )

    nengo.Connection(input, pre)

    # Population representing the one-hot class output.
    post = nengo.Ensemble(
        n_neurons=n_neurons,
        dimensions=output_size,
    )

    # Learned connection. Passing target arrays as ``function`` together with
    # ``eval_points`` makes nengo solve the initial decoders so that each
    # training image maps to its one-hot target. The PES rule then adjusts
    # the decoders online.
    conn = nengo.Connection(
        pre,
        post,
        synapse=0.01,
        eval_points=X_train,
        function=T_train,
        learning_rule_type=nengo.PES(),
    )

    # Error signal for learning: decoded output (post) minus the label.
    error = nengo.Ensemble(n_neurons=n_neurons, dimensions=output_size)
    nengo.Connection(post, error)
    nengo.Connection(label, error, transform=-1)
    nengo.Connection(error, conn.learning_rule)

    # Scalar summary of the error magnitude. Use the mean *absolute* error: a
    # signed mean lets components cancel. For example, outputting the wrong
    # one-hot class gives one +1 and one -1 entry, which average to 0 and
    # would look like a perfect prediction.
    avg_error = nengo.Ensemble(n_neurons=10, dimensions=1)
    nengo.Connection(error, avg_error, function=lambda e: np.mean(np.abs(e)))