Additional variables in Signals for custom learning rules

Hi community! We are writing a new custom learning rule based on BCM and PES. The rule needs additional state stored per synapse, for example a counter of activation times or a memory for custom values. Based on the documentation and posts on this forum, like this one, we have a working custom learning rule with minimal functionality.

It implements a class that inherits from LearningRuleType:

import numpy as np

import nengo
from nengo.builder import Builder, Operator, Signal
from nengo.builder.learning_rules import build_or_passthrough, get_post_ens, get_pre_ens
from nengo.builder.operator import Reset
from nengo.learning_rules import LearningRuleType, _remove_default_post_synapse
from nengo.params import Default, NumberParam
from nengo.synapses import Lowpass, SynapseParam


class HebbianRewardedLearning(LearningRuleType):
    modifies = "weights"
    probeable = ("error", "pre_filtered", "post_filtered", "delta")

    learning_rate = NumberParam("learning_rate", low=0, readonly=True, default=1e-4)
    pre_synapse = SynapseParam("pre_synapse", default=Lowpass(tau=0.005), readonly=True)
    post_synapse = SynapseParam("post_synapse", default=None, readonly=True)

    def __init__(self,
                 learning_rate=Default,
                 pre_synapse=Default,
                 post_synapse=Default
                 ):
        super().__init__(learning_rate, size_in=1)
        # super().__init__(learning_rate, size_in="post_state")
        self.pre_synapse = pre_synapse
        self.post_synapse = (
            self.pre_synapse if post_synapse is Default else post_synapse
        )

    @property
    def _argreprs(self):
        return _remove_default_post_synapse(super()._argreprs, self.pre_synapse)
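
For reference, the rule is then attached to a connection in the usual way (the network below is just a placeholder, not our actual model):

with nengo.Network() as net:
    pre = nengo.Ensemble(50, dimensions=1)
    post = nengo.Ensemble(50, dimensions=1)
    conn = nengo.Connection(
        pre.neurons,
        post.neurons,
        transform=np.zeros((post.n_neurons, pre.n_neurons)),
        learning_rule_type=HebbianRewardedLearning(learning_rate=1e-4),
    )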

We have an operator with a step function:

class SimHebbianRewarded(Operator):

    ...  # initialise variables for our custom rule here

    def __init__(self, pre_filtered, post_filtered, weights, error,
                 my_var, delta, learning_rate, tag=None):
        super().__init__(tag=tag)

        self.learning_rate = learning_rate

        self.sets = []
        self.incs = [my_var]
        self.reads = [pre_filtered, post_filtered, error, weights]
        self.updates = [delta]

    ...  # getters and setters go here

    def get_delta_increase(self, weight, my_var):
        print(my_var)

        # the following lines fail to set the variable or generate errors
        # self.my_var[...] = 999
        # Change(self.my_var[...], 999)
        # Change(my_var, 999)
        return np.sign(weight) * my_var * (self.weight_max - np.abs(weight))


    def make_step(self, signals, dt, rng):
        weights = signals[self.weights]
        pre_filtered = signals[self.pre_filtered]
        post_filtered = signals[self.post_filtered]
        error = signals[self.error]
        delta = signals[self.delta]
        my_var = signals[self.my_var]

        def step_simpes():
            delta[...] = self.get_delta_increase(weights, my_var)

        return step_simpes
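
The elided getters simply expose entries of the dependency lists as properties, roughly:

    @property
    def pre_filtered(self):
        return self.reads[0]

    @property
    def post_filtered(self):
        return self.reads[1]

    @property
    def error(self):
        return self.reads[2]

    @property
    def weights(self):
        return self.reads[3]

    @property
    def my_var(self):
        return self.incs[0]

    @property
    def delta(self):
        return self.updates[0]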

The build function is registered with the Builder:

@Builder.register(HebbianRewardedLearning)
def build_hebbian_rewarded_learning(model, hebbian, rule):

    conn = rule.connection
    pre_activities = model.sig[get_pre_ens(conn).neurons]["out"]
    post_activities = model.sig[get_post_ens(conn).neurons]["out"]
    pre_filtered = build_or_passthrough(model, hebbian.pre_synapse, pre_activities)
    post_filtered = build_or_passthrough(model, hebbian.post_synapse, post_activities)

    # Create input error signal
    error = Signal(shape=rule.size_in, name="HebbianRewardedLearning:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    my_var = Signal(shape=(1,), name="my_var")
    model.add_op(Change(my_var, 13))
    model.sig['my_key']['my_var'] = my_var
    # model.sig[conn]["my_var"] = my_var

    model.add_op(
        SimHebbianRewarded(
            pre_filtered,
            post_filtered,
            model.sig[conn]["weights"],
            error,
            model.sig['my_key']['my_var'],
            model.sig[rule]["delta"],
            hebbian.learning_rate,
            tag='my_learning_rule')
    )

    # expose these for probes
    model.sig[rule]["pre_filtered"] = pre_filtered
    model.sig[rule]["post_filtered"] = post_filtered

I’d like to be able to read and write a variable within get_delta_increase(…), where I can save extra state. On the biological level, I imagine this as a synapse that implements additional processes which depend on past events. My first attempt was to add a new Signal to the model.sig dictionary of dictionaries; however, I cannot set this Signal during the learning rule's step function. I implemented an Operator to do this, but it is not active during the step function, as it is not hooked into the correct place in model.sig[][]:


class Change(Operator):

    def __init__(self, dst, value=0, tag=None):
        super().__init__(tag=tag)
        self.value = float(value)

        self.sets = []
        self.incs = [dst]
        self.reads = []
        self.updates = []

    @property
    def dst(self):
        return self.incs[0]

    @property
    def _descstr(self):
        return str(self.dst)

    def make_step(self, signals, dt, rng):
        target = signals[self.dst]
        value = self.value

        def step_reset():
            target[...] = value

        return step_reset

To sum it up: I need some extra variables per synapse which I can manipulate during learning, in addition to the weights. Ideally they would be linked to a synapse, but they could also live independently somewhere else in the data structure.
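
Concretely, inside the build function I imagine something like a Signal shaped like the weight matrix (my_var here is hypothetical):

# hypothetical per-synapse state, one value per weight
my_var = Signal(
    initial_value=np.zeros(model.sig[conn]["weights"].shape),
    name="HebbianRewardedLearning:my_var",
)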

I would be grateful for any hint on how to achieve this within the Nengo way of thinking and implementing.

It looks like you’re on the right track. It’s not clear to me exactly what problem you’re running into.

A few minor points (which maybe will help):

  • Set your signal on the learning rule like so. This will make sure the signal is associated with the learning rule object:
model.sig[rule]['my_var'] = my_var
  • You should then be able to add “my_var” to the probeable list, which will allow you to probe it and monitor it during simulation.
  • What is the idea behind Change? It looks like it’s just setting the signal to a particular value. If that is what you want, we have a Reset operator for that purpose. As currently implemented, your Change operator is a copy of Reset, except that dst is marked as an incremented signal (it is in incs) even though it is set rather than incremented.
  • Is there some reason to update my_var separately from the SimHebbianRewarded operator? I would probably just update it there. You can update it in the same way that delta is updated: have it as part of the updates list, and in your step function do my_var[...] = new_value, as in the sketch below.
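
For example, a minimal sketch of that approach (only the my_var plumbing matters here; the actual update values are placeholders):

class SimHebbianRewarded(Operator):
    def __init__(self, pre_filtered, post_filtered, weights, error,
                 my_var, delta, learning_rate, tag=None):
        super().__init__(tag=tag)
        self.learning_rate = learning_rate

        self.sets = []
        self.incs = []
        self.reads = [pre_filtered, post_filtered, error, weights]
        self.updates = [delta, my_var]  # my_var is written each step, so it goes in updates

    @property
    def weights(self):
        return self.reads[3]

    @property
    def delta(self):
        return self.updates[0]

    @property
    def my_var(self):
        return self.updates[1]

    def make_step(self, signals, dt, rng):
        weights = signals[self.weights]
        delta = signals[self.delta]
        my_var = signals[self.my_var]

        def step_simhebbian():
            my_var[...] = my_var + 1  # placeholder state update, written through the view
            delta[...] = np.sign(weights) * my_var  # placeholder weight change

        return step_simhebbian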

Thank you very much for the answer and the code proposal!

Here is how I finally solved it. The custom variable is called weight_floor:

In the build function:

@Builder.register(HebbianRewardedLearning)
def build_hebbian_rewarded_learning(model, hebbian, rule):
    ...
    weight_floor = Signal(initial_value=0.0, name='HebbianRewardedLearning:weight_floor')
    model.sig[rule]['weight_floor'] = weight_floor
    ...
    model.add_op(
        SimHebbianRewarded(
            pre_filtered,
            post_filtered,
            weight_floor=model.sig[rule]['weight_floor'],
            tag='my_learning_rule',
        )
    )

In the Operator class:


class SimHebbianRewarded(Operator):
    ...
    self.reads = [pre_filtered, post_filtered,
                  reward, weights,
                  pre_activities, post_activities,
                  weight_floor]
    ...

    @property
    def weight_floor(self):
        return self.reads[6]

    def make_step(self, signals, dt, rng):
        ...
        weight_floor = signals[self.weight_floor]

        def step_simpes():
            ...
            weight_floor[...] += some_value


The Change operator was indeed not necessary.

Adding it to the probeable list would help, but I cannot set a probe on a learning rule. The error message states that probes can only target Nengo objects. Is there a quick workaround for that?

Also, I’m wondering whether the new variable has to be specified explicitly in the self.reads list. Why not in the sets list, since I’m both reading and setting it?