hi drausmuss
So I tried to replicate similar idea which you explained above but used gym cartpole environment for testing purpose.
import nengo
import numpy as np
import gym
model=nengo.Network()
env = gym.make('CartPole-v0').env
class EnvironmentInterface(object):
def __init__(self,env,stepSize =5):
self.env = env
self.n_actions = env.action_space.n
self.state_dim = env.observation_space.shape[0]
self.t=0
self.stepsize = stepSize
self.output = np.zeros(self.n_actions)
self.state = env.reset()
self.reward= 0
self.current_action = 0
def take_action(self,action):
self.state,self.reward,self.done,_=env.step(action)
if self.done:
# self.reward = -2
self.state = env.reset()
def get_reward(self,t):
return self.reward
def sensor(self,t):
return self.state
def step(self,t,x):
if int(t*1000)%self.stepsize == 0:
self.current_action = np.argmax(x) #np.argmax(self.output)
self.take_action(self.current_action)
def calculate_Q(self,t,x):
if int(t*1000) % self.stepsize == 1:
qmax = x[np.argmax(x)]
op = np.zeros(self.n_actions)
op[self.current_action] = 0.9*qmax + self.reward
self.output = op
return self.output
def step2(self,t,x):
if int(t*1000) == 1:
print("STARTING")
if int(t * 1000)%self.stepsize == 0:
qs = self.output[5:]
self.current_action = np.argmax(qs)
self.take_action(self.current_action)
elif int(t * 1000) % self.stepsize == 1:
qvals = x
qmax = qvals[np.argmax(qvals)]
c_output = np.zeros(self.n_actions)
c_output[self.current_action] = qvals[self.current_action]
f_output = np.zeros(self.n_actions)
f_output[self.current_action] = 0.9*qmax + self.reward
self.output = np.concatenate((c_output,f_output,qvals))
return self.output
tau = 0.01
fast_tau = 0
slow_tau = 0.01
n_action =2
envI=EnvironmentInterface(env)
state_dimensions=envI.state_dim
n_actions = envI.n_actions
with model:
sensor = nengo.Node(envI.sensor)
reward = nengo.Node(envI.get_reward)
sensor_net = nengo.Ensemble(n_neurons=1000,dimensions=envI.state_dim,radius=4)
nengo.Connection(sensor,sensor_net)
action_net = nengo.Ensemble(n_neurons=1000,dimensions=envI.n_actions,radius=4)
learning_conn=nengo.Connection(sensor_net,action_net,function=lambda x:[0,0],learning_rule_type=nengo.PES(1e-3, pre_tau=slow_tau),synapse=tau)
q_node = nengo.Node(envI.calculate_Q,size_in=2,size_out=2)
step_node = nengo.Node(envI.step,size_in=2)
nengo.Connection(action_net,step_node,synapse=fast_tau)
nengo.Connection(action_net,q_node,synapse=tau)
nengo.Connection(q_node,learning_conn.learning_rule,transform =1,synapse=fast_tau) ##0.9*Q(s',a')+r
nengo.Connection(action_net,learning_conn.learning_rule,transform =-1,synapse=slow_tau)#Q(s,a)
But looks like my implementation has some flaw as Q value of one action becomes very high and cartpole iteration never reached beyond few steps.
Is my implementation of action value function fine?