[NengoDL] Signals.scatter() shape issue

I’m working on porting my custom learning rule to NengoDL and feel I’m 90% there. There are just a few issues to iron out, the main one being that signals.scatter() is having trouble dealing with my Tensor shapes.

I have a matrix of resistance values that I would like to read and update at each timestep in my NengoDL learning rule.
The matrix is defined in my build_mpes() build function and has shape (post.n_neurons, pre.n_neurons), i.e. the same layout as the weights and delta Signals:

memristors = Signal( shape=(out_size, in_size), name="mPES:memristors" )
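
For context, here is a minimal sketch of how that Signal gets attached to the operator so that op.memristors is available to the NengoDL builder later on (the operator and build-function details below are simplified assumptions, not my actual code, which lives in the linked repo):

    from nengo.builder import Operator, Signal

    class SimmPES( Operator ):
        def __init__( self, memristors, tag=None ):
            super().__init__( tag=tag )
            # expose the Signal so the NengoDL builder can combine it across ops
            self.memristors = memristors
            self.sets = [ ]
            self.incs = [ ]
            self.reads = [ ]
            self.updates = [ memristors ]
            # make_step() and the other signals are omitted in this sketch

    def build_mpes( model, mpes, rule ):
        conn = rule.connection
        # same layout as the weights/delta Signals: (post.n_neurons, pre.n_neurons)
        out_size, in_size = model.sig[ conn ][ "weights" ].shape
        memristors = Signal( shape=(out_size, in_size), name="mPES:memristors" )
        model.add_op( SimmPES( memristors ) )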

I then get the corresponding TensorSignal in the __init__() of my NengoDL SimmPESBuilder class and reshape it to (1, post.n_neurons, pre.n_neurons):

    self.memristors = signals.combine( [ op.memristors for op in ops ] )
    self.memristors = self.memristors.reshape(
            (len( ops ), ops[ 0 ].memristors.shape[ 0 ], ops[ 0 ].memristors.shape[ 1 ])
            )

I do the reshape so that I get a Tensor of shape (1, 1, post.n_neurons, pre.n_neurons) in the build_step() function when using:

memristors = signals.gather( self.memristors )

I then proceed to update memristors using tf.tensor_scatter_nd_update(), but when I try to write back to the TensorSignal by using:

signals.scatter( self.memristors, memristors )

I get the following error (with post.n_neurons = pre.n_neurons = 4):

ValueError: The outer 2 dimensions of indices.shape=[1,4,2] must match the outer 2 dimensions of updates.shape=[1,1,4,4]: Dimension 1 in both shapes must be equal, but are 4 and 1. Shapes are [1,4] and [1,1]. for 'TensorGraph/while/iteration_0/SimmPESBuilder/TensorScatterUpdate' (op: 'TensorScatterUpdate') with input shapes: [1,4,4], [1,4,2], [1,1,4,4].
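
For reference, the relevant part of my build_step() roughly looks like this (simplified; the actual update logic is replaced by a placeholder):

    def build_step( self, signals ):
        # shape (1, 1, n_post, n_pre): the per-op axis from the reshape in
        # __init__() plus the minibatch axis, since the Signal is minibatched
        memristors = signals.gather( self.memristors )

        # ... learning-rule logic producing an updated Tensor of the same
        # (1, 1, n_post, n_pre) shape, e.g. via tf.tensor_scatter_nd_update() ...
        new_memristors = memristors  # placeholder for the real update

        # write the updated resistances back; this is the call that fails
        signals.scatter( self.memristors, new_memristors )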

I have looked into nengo_dl/signals.py and the issue seems to be caused by line 349: dst_shape[dst.minibatched] = dst.shape[0] (https://github.com/nengo/nengo-dl/blob/72242a2e826c172ddfe57cd5731ef7b6070315a2/nengo_dl/signals.py#L349):

    # align val shape with dst base shape
    val.shape.assert_is_fully_defined()
    dst_shape = list(base_shape)
    dst_shape[dst.minibatched] = dst.shape[0]
    if val.shape != dst_shape:
        val = tf.reshape(val, dst.tf_shape)

Specifically, on line 348 (dst_shape = list(base_shape)) dst_shape is assigned [1,4,4], i.e. self.memristors.shape, but the next line turns it into [1,1,4] because dst.minibatched is True.
So by the time we reach line 367 (result = tf.tensor_scatter_nd_update(var, dst.tf_indices_nd, val)) the dimensions of var and val no longer line up.
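
The same shape constraint can be reproduced outside NengoDL with a plain TensorFlow snippet (shapes taken from the error message above):

    import tensorflow as tf

    var = tf.zeros( (1, 4, 4) )                                    # scatter target, like the base array
    indices = tf.constant( [ [ [ 0, r ] for r in range( 4 ) ] ] )  # shape (1, 4, 2)

    good = tf.ones( (1, 4, 4) )    # indices.shape[:-1] + var.shape[2:] -> accepted
    bad = tf.ones( (1, 1, 4, 4) )  # extra leading dim, like the gathered Tensor

    tf.tensor_scatter_nd_update( var, indices, good )  # works
    tf.tensor_scatter_nd_update( var, indices, bad )   # fails with a shape error like the one above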

So my guess is that I should drop the reshape of self.memristors in the __init__() of SimmPESBuilder (https://github.com/Tioz90/MemristorLearning/blob/b6eaf7a5627f25bbb999ea3f3444b3ad68eeac3d/memristor_nengo/learning_rules.py#L279) so that the TensorSignal keeps its (4, 4) shape, but then I would have to modify all my learning-rule logic, which expects every Tensor to have shape (1, 1, 4, 4), exactly as the built-in implementation of PES does (https://github.com/nengo/nengo-dl/blob/72242a2e826c172ddfe57cd5731ef7b6070315a2/nengo_dl/learning_rule_builders.py#L268).

What would your suggestions be?

I made a new branch of my code (https://github.com/Tioz90/MemristorLearning/blob/a8d222857f26457a70a504faebc8019e2e209231/memristor_nengo/learning_rules.py#L270) because I had a feeling that "Building a custom learning rule operator for Nengo DL" might help with my issue.
Why? Because then I can safely remove the reshape() on my memristor TensorSignals that was only needed to add the extra dimension keeping track of the number of `ops`.

And, lo and behold, at first glance it seems as if the issue with .scatter() has been resolved without changing anything in build_step()!

I’m not sure whether the way I enforce the single op in __init__() is the “proper” way, and I would obviously like to keep operator merging as a possibility in order to get all the performance I can.
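For reference, by “enforcing the single op” I mean roughly this (a sketch of the idea; the real code is in the branch linked above):

    # fail loudly if NengoDL ever hands this builder merged operators, so the
    # extra per-op dimension (and its reshape) can be dropped entirely
    assert len( ops ) == 1, "SimmPESBuilder currently assumes un-merged operators"
    self.memristors = signals.combine( [ op.memristors for op in ops ] )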
Do you think the behaviour of scatter() can be improved to deal better with a situation like mine? Or is it a case of changing something else in my code?

I think just reshaping memristors back to the correct shape ((4, 4)) before the scatter should work, shouldn’t it? Something like signals.scatter(self.memristors.reshape((4, 4)), memristors)

Thanks! The code runs perfectly now! Sometimes the simplest solution is the last one we think of 🙂

Would operator merging still work correctly with the .reshape() as you suggested?

Yeah, all the merging has already happened before we get to the build_step function, so doing reshapes or anything like that within your build process can’t make things un-mergeable (we just have to make sure that all our reshaping is consistent with the merged data we’re getting as input).
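
For example, here is an untested sketch of a merge-friendly write-back, assuming signals.combine stacked the per-op matrices along the first axis:

    # self.memristors was reshaped to (len(ops), n_post, n_pre) in __init__();
    # undo that just for the write-back so scatter sees the combined base shape
    n_ops, n_post, n_pre = self.memristors.shape
    signals.scatter( self.memristors.reshape( (n_ops * n_post, n_pre) ), memristors )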
