The first line
learn_conn = nengo.Connection(state, value, function=lambda x: 0,
                              learning_rule_type=nengo.PES(learning_rate=1e-4,
                                                           pre_tau=tau))
is creating the connection that will approximate the value function (mapping from the Ensemble representing the state to the 1D value Ensemble representing the scalar value for that state). The function is initialized to output zero. The next three lines
nengo.Connection(reward, learn_conn.learning_rule,
                 transform=-1, synapse=tau)
nengo.Connection(value, learn_conn.learning_rule,
                 transform=-0.9, synapse=0.01)
nengo.Connection(value, learn_conn.learning_rule,
                 transform=1, synapse=tau)
are setting up the standard TD error formula error = reward + discount * V(s') - V(s), where the three lines are computing the reward, discount * V(s'), and V(s) terms, in order (with discount = 0.9 here). (In this case the sign is flipped, giving -error = -reward - discount * V(s') + V(s), since the PES rule expects its error signal as "actual minus target", but it's the same idea.)
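If it helps to see the error computation in one place, here is a rough sketch of the same thing routed through an explicit error ensemble rather than connecting straight to the learning rule. The names value, reward, learn_conn, and tau are taken from the snippets above; the error ensemble and its neuron count are my own additions for illustration, and this assumes everything lives inside the same model context.

# Sketch: assemble -error = -reward - 0.9 * V(s') + V(s) in its own population
error = nengo.Ensemble(n_neurons=100, dimensions=1)  # illustrative size

nengo.Connection(reward, error, transform=-1, synapse=tau)    # -reward
nengo.Connection(value, error, transform=-0.9, synapse=0.01)  # -discount * V(s')
nengo.Connection(value, error, transform=1, synapse=tau)      # +V(s)

# Feed the assembled -error signal to the PES rule on learn_conn
nengo.Connection(error, learn_conn.learning_rule)

Functionally this is roughly the same as the three direct connections above; the only difference is that the error now exists as its own population, which can be handy if you want to probe it.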
Note that V(s)
and V(s')
are computed by applying different temporal filters to the value signal (the synapse
argument). We can think of this, roughly speaking, as adding a delay to the signal (that’s not really true, but just for the sake of explanation). So when we take the value signal, and apply different temporal filters (delays), that allows us to approximate the value of that signal from different points in time, which is the key to computing the TD error.