Reinforcement learning with image input

Good day,

I am very new to Nengo and RL in general, so this might be a relatively easy question. I am trying to have a neural network use reinforcement learning to move a white pixel in a 3x3 grid to the bottom center of the grid. For instance, if the white pixel starts at the top right, the end goal would be for it to end up at the bottom center.

The network that I currently have is very similar to the one in the RL example in NengoFPGA, except that there are 4 actions and the input is a 9-dimensional Node.

The reward is generated as follows. If the white square is in the ideal place, the reward is 1. If the white square is anywhere else inside the grid, the reward is -0.5. Finally, if the white square is moved out of the frame, the reward is -1 and the white square is randomly placed in the top row of the 3x3 grid to start a new epoch.
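The reward rule above can be sketched roughly as follows. This is only an illustrative stand-in, not the actual `EventEnvironment`: the class name `GridEnvSketch`, the action ordering, and the assumption that "forward" moves the pixel down one row are all hypothetical.

```python
import numpy as np

# Hypothetical action indices; the real environment may order them differently.
FORWARD, LEFT, RIGHT, STOP = 0, 1, 2, 3
GOAL = (2, 1)  # bottom-center cell (row, column) of the 3x3 grid


class GridEnvSketch:
    """Illustrative stand-in built only from the reward description above."""

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        # New epoch: the white pixel is placed at random in the top row.
        self.pos = (0, int(self.rng.integers(3)))

    def get_observation(self):
        # 3x3 image with a single white (1.0) pixel.
        obs = np.zeros((3, 3))
        obs[self.pos] = 1.0
        return obs

    def step(self, action):
        row, col = self.pos
        # Assumed movement model: "forward" moves the pixel down one row.
        if action == FORWARD:
            row += 1
        elif action == LEFT:
            col -= 1
        elif action == RIGHT:
            col += 1
        # STOP leaves the pixel where it is.
        if not (0 <= row < 3 and 0 <= col < 3):
            # Pixel left the frame: penalize and start a new epoch.
            self.reset()
            return -1.0
        self.pos = (row, col)
        return 1.0 if self.pos == GOAL else -0.5
```

With this sketch, `get_observation().flatten()` gives the 9-dimensional input and `step(action)` returns the scalar reward, matching how `observe` and `perform_action` use the environment in the code below.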

There must be something that I don’t understand in either the reward generation or the PES learning rule, because the network doesn’t seem to learn at all. The error signal quickly goes to zero, and the network usually stabilizes in a non-ideal state, i.e. it does not move the white pixel to the bottom center, as if it were stuck in a local minimum. I would love to have your thoughts on this.

I am putting the code below for reference. Don’t hesitate to ask if anything is unclear. Any help, tips, or tricks are appreciated.

Also, since this is my first question, any feedback as to how to better ask questions is welcome!

Thanks

import nengo
from nengo_extras.gui import image_display_function
import numpy as np

from environments.event_environment import EventEnvironment
environment = EventEnvironment()

model = nengo.Network()


def init_weights(shape):
    return np.random.uniform(-1e-3, 1e-3, size=shape)


def observe(t):
    # Nengo calls a Node function with the simulation time t; it is unused here.
    obs = environment.get_observation()
    return obs.flatten()


with model:
    #### Visual System Sensory input ####
    #####################################
    # 0    1    2
    # 3    4    5
    # 6    7    8
    #####################################
    visual_stimulus = nengo.Node(observe, label="Visual Sensor")
    
    visual_neurons = nengo.Ensemble(n_neurons=50,
                                    dimensions=9,
                                    label="Visual Neurons")
    
    nengo.Connection(visual_stimulus,
                    visual_neurons)
    
    
with model:
    #### Action selection ####
    def go_forward(x):
        return 0.9
    
    def go_left(x):
        return 0.8
    
    def go_right(x):
        return 0.7

    def stop(x):
        return 0.6
    
    basal_ganglia = nengo.networks.BasalGanglia(4)
    thalamus = nengo.networks.Thalamus(4)
    
    nengo.Connection(basal_ganglia.output,
                    thalamus.input)
    
    forward = nengo.Connection(visual_neurons.neurons,
                            basal_ganglia.input[0],
                            transform=init_weights((1, 50)),
                            learning_rule_type=nengo.PES(learning_rate=1e-4))
    left = nengo.Connection(visual_neurons.neurons,
                            basal_ganglia.input[1],
                            transform=init_weights((1, 50)),
                            learning_rule_type=nengo.PES(learning_rate=1e-4))
    right = nengo.Connection(visual_neurons.neurons,
                            basal_ganglia.input[2],
                            transform=init_weights((1, 50)),
                            learning_rule_type=nengo.PES(learning_rate=1e-4))
    stop = nengo.Connection(visual_neurons.neurons,
                            basal_ganglia.input[3],
                            transform=init_weights((1, 50)),
                            learning_rule_type=nengo.PES(learning_rate=1e-4))
           
           
def perform_action(t, x):
    # Need to add some sort of exploration phase here
    action = np.argmax(x)
    reward = environment.step(action)
    return reward
                    
with model:
    #### Learning ####
    reward_node = nengo.Node(perform_action, size_in=4, label="Reward")
    error = nengo.networks.EnsembleArray(n_neurons=50,
                                            n_ensembles=4,
                                            radius=1,
                                            label="Error")
                                            
    nengo.Connection(thalamus.output, reward_node)
    
    nengo.Connection(reward_node, error.input, transform=-np.ones((4, 1)), synapse=0)
    
    nengo.Connection(basal_ganglia.output[0],
                    error.ensembles[0].neurons,
                    transform=np.ones((50, 1))*4)

    nengo.Connection(basal_ganglia.output[1],
                    error.ensembles[1].neurons,
                    transform=np.ones((50, 1))*4)

    nengo.Connection(basal_ganglia.output[2],
                    error.ensembles[2].neurons,
                    transform=np.ones((50, 1))*4)

    nengo.Connection(basal_ganglia.output[3],
                    error.ensembles[3].neurons,
                    transform=np.ones((50, 1))*4)

    nengo.Connection(basal_ganglia.input, error.input, transform=1, synapse=0)
    
    
    nengo.Connection(error.ensembles[0], forward.learning_rule)
    nengo.Connection(error.ensembles[1], left.learning_rule)
    nengo.Connection(error.ensembles[2], right.learning_rule)
    nengo.Connection(error.ensembles[3], stop.learning_rule)
    
    
    
with model:
    #### Visualizations ####
    input_display_node = nengo.Node(image_display_function((1, 3, 3)),
                                    size_in=visual_stimulus.size_out,
                                    label="Input Display")
    nengo.Connection(visual_stimulus, input_display_node)

Hi Bouboul,
I think one way to improve almost any coding question is to include complete details of how to recreate the problem; in this case, runnable code.
For my part, possibly due to my own deficiencies, I cannot find any reference to the class EventEnvironment or the environments library, so I cannot make the code work right away. Others might find it obvious, of course! From a quick look, I take it that it feeds the 3x3 ‘images’ into the program.
Are you able to provide more detail, please?