[NengoDL] Issue with running custom neuron on GPU

I have written a custom neuron model for NengoDL that throws the following error when run on a machine with a GPU:

Build finished in 0:00:00
Optimization finished in 0:00:00
| #                       Constructing graph                          | 0:00:00
2021-01-13 11:29:01.448067: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2021-01-13 11:29:01.449333: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2021-01-13 11:29:01.450732: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2021-01-13 11:29:01.452037: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
|                         Constructing graph   #                      | 0:00:04
Traceback (most recent call last):
  File "mnist.py", line 138, in <module>
    with nengo_dl.Simulator( model ) as sim_train:
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/simulator.py", line 526, in __init__
    self._build_keras(progress)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo/utils/magic.py", line 181, in __call__
    return self.wrapper(self.__wrapped__, self.instance, args, kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/simulator.py", line 50, in with_self
    output = wrapped(*args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/simulator.py", line 549, in _build_keras
    outputs = self.tensor_graph(inputs, stateful=self.stateful, progress=progress)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 925, in __call__
    return self._functional_construction_call(inputs, args, kwargs,
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1117, in _functional_construction_call
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/tensor_graph.py", line 488, in call
    self._build_loop(sub) if self.use_loop else self._build_no_loop(sub)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/tensor_graph.py", line 638, in _build_loop
    loop_vars = tf.while_loop(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
    return func(*args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
    return while_loop(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
    return while_v2.while_loop(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
    body_graph = func_graph_module.func_graph_from_py_func(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
    outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/tensor_graph.py", line 615, in loop_body
    loop_i = self._build_inner_loop(loop_i, update_probes, progress)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/tensor_graph.py", line 771, in _build_inner_loop
    side_effects = self.op_builder.build_step(self.signals, progress)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/builder.py", line 101, in build_step
    output = self.op_builds[ops].build_step(signals)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/neuron_builders.py", line 535, in build_step
    self.built_neurons.build_step(signals)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/nengo_dl/neuron_builders.py", line 191, in build_step
    smart_cond.smart_cond(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py", line 58, in smart_cond
    return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 1180, in cond
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py", line 95, in cond_v2
    return _build_cond(
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py", line 221, in _build_cond
    _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  File "/home/p291020/.conda/envs/nengodl/lib/python3.8/site-packages/tensorflow/python/ops/cond_v2.py", line 634, in _make_indexed_slices_indices_types_match
    assert len(set(outs_per_branch)) == 1, outs_per_branch
AssertionError: [5, 3]

I thought the issue might have been the tf.cond() call in my model, but removing it still produces the identical error.

On the GPU machine I am running:

tensorflow==2.3.0
nengo-dl==3.4.0
nengo==3.1.0

It runs without issue on the CPU of my local machine, which uses:

tensorflow==2.4.0
nengo-dl==3.4.0
nengo==3.1.0

Could it be an issue with TensorFlow 2.3.0?

My implementation can be found here.

Thanks in advance for any help!

If it works on the CPU, it sounds like the problem isn’t in your code; most likely it’s some odd bug in TensorFlow. I’d try using the same TensorFlow version on both machines, just to see whether the error is specific to one version.
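To compare the two environments side by side, a small stdlib-only sketch like the following could report the installed versions on each machine and list the packages that differ. It assumes Python 3.8+ (for importlib.metadata) and the distribution names tensorflow, nengo-dl and nengo; the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the version lists above (assumed PyPI names).
PACKAGES = ("tensorflow", "nengo-dl", "nengo")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def differing_versions(env_a, env_b):
    """Return {name: (version_a, version_b)} for packages whose versions differ."""
    names = sorted(set(env_a) | set(env_b))
    return {
        name: (env_a.get(name), env_b.get(name))
        for name in names
        if env_a.get(name) != env_b.get(name)
    }


if __name__ == "__main__":
    # Run this on each machine and compare the printed pins.
    for name, ver in installed_versions().items():
        print(f"{name}=={ver}")
```

With the versions reported in this thread, `differing_versions({"tensorflow": "2.3.0", "nengo-dl": "3.4.0", "nengo": "3.1.0"}, {"tensorflow": "2.4.0", "nengo-dl": "3.4.0", "nengo": "3.1.0"})` returns only the tensorflow entry, so apart from GPU vs CPU the TensorFlow version is the one difference among these three packages.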

You might also be able to avoid the issue by adding tf.compat.v1.disable_control_flow_v2() at the top of your script. Perhaps TensorFlow’s older control flow V1 implementation doesn’t have this bug.
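Concretely, that suggestion amounts to a script header along these lines. This is a sketch: the call has to happen before the model and nengo_dl.Simulator are built, and `tf.compat.v1.control_flow_v2_enabled()` is used only to confirm the setting took effect. Whether it actually sidesteps this particular error was not confirmed in the thread.

```python
import tensorflow as tf

# Fall back to the V1 implementations of tf.cond / tf.while_loop.
# Call this before any model or nengo_dl.Simulator is constructed,
# since the flag is read when the control-flow ops are created.
tf.compat.v1.disable_control_flow_v2()

# Sanity check that the setting is in effect.
print("control flow v2 enabled:", tf.compat.v1.control_flow_v2_enabled())

# ... then import nengo / nengo_dl and build the model as before.
```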

I can confirm that moving to tensorflow==2.4.0 fixed the issue, but you may still want to look into it, since nengo-dl==3.4.0 should be compatible with tensorflow==2.3.0.

See also here: https://github.com/nengo/nengo-dl/issues/198

I would double-check that things actually work in TensorFlow 2.4. I’m pretty sure you’re forgetting to return those states (as mentioned in the GitHub issue), so I suspect it’s just failing silently in TF 2.4 instead of raising an error.
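The error message itself points at this. TensorFlow’s `cond_v2` requires both branches of a conditional to return the same number of outputs (the `assert len(set(outs_per_branch)) == 1, outs_per_branch` line in the traceback), and NengoDL routes the neuron update through such a conditional (`smart_cond` in `neuron_builders.py`). The pure-Python sketch below mimics that check to show why a step function that drops some of its states fails it. The toy neuron, its state names and all helper functions are hypothetical and are not the NengoDL builder API.

```python
def toy_adaptive_step(dt, J, voltage, adaptation, tau_rc=0.02, tau_adapt=0.1, inc=0.01):
    """Hypothetical adaptive integrate-and-fire update on scalar floats.

    Returns the output followed by every state variable, in a fixed order.
    """
    voltage = voltage + dt * (J - adaptation - voltage) / tau_rc
    spiked = voltage > 1.0
    output = 1.0 / dt if spiked else 0.0
    if spiked:
        voltage = 0.0  # reset after a spike
    adaptation = adaptation - dt * adaptation / tau_adapt + (inc if spiked else 0.0)
    return output, voltage, adaptation


def forgetful_step(dt, J, voltage, adaptation):
    """Same update, but the adaptation state is (wrongly) not returned."""
    output, voltage, adaptation = toy_adaptive_step(dt, J, voltage, adaptation)
    return output, voltage


def check_branches_match(true_fn, false_fn, *args):
    """Mimic the cond_v2 check: both branches must return equally many values."""
    outs_per_branch = [len(true_fn(*args)), len(false_fn(*args))]
    if len(set(outs_per_branch)) != 1:
        raise AssertionError(outs_per_branch)
    return outs_per_branch
```

Here `check_branches_match(toy_adaptive_step, toy_adaptive_step, 0.001, 2.0, 0.0, 0.0)` returns `[3, 3]`, while pairing `toy_adaptive_step` with `forgetful_step` raises `AssertionError: [3, 2]`, the same kind of per-branch mismatch as the `[5, 3]` in the traceback.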

I had been going crazy because my neuron behaved differently in Nengo Core and in NengoDL, and this was exactly the problem: I’d overlooked returning my extra states!

For completeness’ sake, I tested the original code (returning only 3 of the 5 states) with TensorFlow 2.3 on my local machine’s CPU, and it does not throw an error there.