Loading a `tf.keras.Model` subclass in NengoDL

I have a tf.keras.Model subclass that I want to load into my Nengo network with a TensorNode, as described in this tutorial. I’ve done this successfully with the Keras functional API as shown there. However, now I want to load a model that I defined in a subclass, like:

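```python
import tensorflow as tf

# FlexibleDense is a custom Keras layer defined elsewhere in my code
class MySequentialModel(tf.keras.Model):
    def __init__(self, name=None, **kwargs):
        # pass `name` on to the parent class so it isn't silently dropped
        super().__init__(name=name, **kwargs)
        self.dense_1 = FlexibleDense(out_features=3)
        self.dense_2 = FlexibleDense(out_features=2)

    def call(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)
```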
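The question doesn't show `FlexibleDense`; it looks like the custom layer from the TensorFlow guide "Introduction to modules, layers, and models". For a self-contained version to experiment with, here is a minimal sketch of a Keras-layer `FlexibleDense` (our own reconstruction, assumed to behave like the question's layer, not copied from it) together with the model:

```python
import tensorflow as tf


class FlexibleDense(tf.keras.layers.Layer):
    """Dense layer whose input size is inferred on the first call (hypothetical reconstruction)."""

    def __init__(self, out_features, **kwargs):
        super().__init__(**kwargs)
        self.out_features = out_features

    def build(self, input_shape):
        # Weights are created lazily, once the size of the last input axis is known
        self.w = self.add_weight(
            name="w", shape=(input_shape[-1], self.out_features), initializer="random_normal"
        )
        self.b = self.add_weight(name="b", shape=(self.out_features,), initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


class MySequentialModel(tf.keras.Model):
    def __init__(self, name=None, **kwargs):
        # Forward `name` to the parent class so it is not silently dropped
        super().__init__(name=name, **kwargs)
        self.dense_1 = FlexibleDense(out_features=3)
        self.dense_2 = FlexibleDense(out_features=2)

    def call(self, x):
        x = self.dense_1(x)
        return self.dense_2(x)


model = MySequentialModel(name="my_model")
y = model(tf.ones((4, 5)))  # the first call creates all four variables
```

Because the variables only appear on the first call, a subclassed model has no fixed input graph for `clone_model` to copy, unlike a functional model built from `tf.keras.Input`.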

Unfortunately, in my `KerasWrapper` (from the tutorial), the line:

```python
self.model = tf.keras.models.clone_model(self.model)
```

fails with:

```
ValueError: Expected `model` argument to be a functional `Model` instance, but got a subclass model instead.
```

It seems I can’t clone a class-based model. Is there another way to do this?

In the NengoDL example you linked, the purpose of the `KerasWrapper` class is to build the Keras model within the `nengo.Network()` context (i.e., inside the `with nengo.Network() as net:` block). Because the example uses a Keras functional model, the `KerasWrapper` class uses the `clone_model` function to rebuild the model when the `KerasWrapper.build` method is called.
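To make "rebuild" concrete, here is a small sketch in plain TensorFlow (no Nengo involved; the layer sizes are arbitrary). `clone_model` recreates the functional model's layer graph with brand-new, freshly initialized variables, which is why the tutorial's wrapper loads the saved weights immediately after cloning:

```python
import numpy as np
import tensorflow as tf

# A small functional model standing in for the tutorial's model (sizes are arbitrary)
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(3, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(2)(hidden)
functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)

# clone_model rebuilds the same architecture with new variables
clone = tf.keras.models.clone_model(functional_model)
shares_variables = clone.weights[0] is functional_model.weights[0]

# The clone starts from fresh initial values, so the trained weights must be copied or loaded
clone.set_weights(functional_model.get_weights())

x = np.random.rand(5, 4).astype("float32")
outputs_match = np.allclose(np.asarray(functional_model(x)), np.asarray(clone(x)))
```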

With the functional model, you could technically redefine the entire Keras model inside the `build` method, like so:

```python
class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shapes):
        super().build(input_shapes)

        # <... copy the functional model definition here, assigning it to self.model ...>

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```

But the purpose of the example is to show how you would write the KerasWrapper class without having to do this (i.e., you don’t have to duplicate code).

With a Keras model subclass, the architecture lives in the class itself, so every new instance of the subclass is a fresh copy of the model with its own newly created layers. Thus, the `KerasWrapper` class can simply create a new instance of the model subclass instead of calling `clone_model`. Assuming your Keras model subclass is `MySequentialModel`, the modified `KerasWrapper` class would look like:

```python
class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shapes):
        super().build(input_shapes)

        # Create Keras model subclass instance
        self.model = MySequentialModel()

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```
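Outside of Nengo, the same pattern can be exercised with plain TensorFlow. This sketch rests on a few assumptions: it uses standard `Dense` layers in place of `FlexibleDense`, saves weights to a temporary `.weights.h5` file in place of the tutorial's `model_weights`, and calls the new instance once on a dummy batch before `load_weights`. That extra call is there because loading HDF5-format weights into a subclassed model generally requires its variables to exist already; other weight formats may not need it.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MySequentialModel(tf.keras.Model):
    # Stand-in for the question's model, using standard Dense layers
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = tf.keras.layers.Dense(3)
        self.dense_2 = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.dense_2(self.dense_1(x))


# "Train" (here: just build) a model and save its weights, as the tutorial does
x = np.random.rand(5, 4).astype("float32")
trained = MySequentialModel()
trained(x)
model_weights = os.path.join(tempfile.mkdtemp(), "model.weights.h5")
trained.save_weights(model_weights)


class KerasWrapper(tf.keras.layers.Layer):
    def build(self, input_shapes):
        super().build(input_shapes)
        # A fresh instance of the subclass plays the role of clone_model
        self.model = MySequentialModel()
        # Run one dummy batch so every variable exists before loading the HDF5 weights
        self.model(tf.zeros((1,) + tuple(input_shapes)[1:]))
        self.model.load_weights(model_weights)

    def call(self, inputs):
        return self.model(inputs)


wrapper = KerasWrapper()
outputs_match = np.allclose(np.asarray(wrapper(x)), np.asarray(trained(x)))
```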

You can then make the `KerasWrapper` class more general by passing the Keras model subclass itself (the class object, not an instance or its name) to `KerasWrapper.__init__()`:

```python
class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self, keras_model_class):
        super().__init__()
        self.model_class = keras_model_class

    def build(self, input_shapes):
        super().build(input_shapes)

        # Create Keras model subclass instance
        self.model = self.model_class()

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```
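One possible further step, not part of the original answer: pass the weights path (and any constructor arguments) into the wrapper as well, so it no longer depends on a global `model_weights`. `TwoLayerModel`, `weights_path`, and `model_kwargs` are hypothetical names used only for this sketch, which again runs a dummy batch before loading HDF5 weights:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TwoLayerModel(tf.keras.Model):
    # Hypothetical subclassed model whose hidden width is a constructor argument
    def __init__(self, hidden=3, **kwargs):
        super().__init__(**kwargs)
        self.dense_1 = tf.keras.layers.Dense(hidden)
        self.dense_2 = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.dense_2(self.dense_1(x))


class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self, keras_model_class, weights_path, model_kwargs=None, **kwargs):
        super().__init__(**kwargs)
        self.model_class = keras_model_class  # the class itself, not an instance
        self.weights_path = weights_path
        self.model_kwargs = dict(model_kwargs or {})

    def build(self, input_shapes):
        super().build(input_shapes)
        # Each build gets its own fresh model instance
        self.model = self.model_class(**dict(self.model_kwargs))
        # Run one dummy batch so every variable exists before loading weights
        self.model(tf.zeros((1,) + tuple(input_shapes)[1:]))
        self.model.load_weights(self.weights_path)

    def call(self, inputs):
        return self.model(inputs)


# Save weights from a "trained" instance, then restore them through the wrapper
x = np.random.rand(5, 4).astype("float32")
trained = TwoLayerModel(hidden=8)
trained(x)
weights_path = os.path.join(tempfile.mkdtemp(), "two_layer.weights.h5")
trained.save_weights(weights_path)

wrapper = KerasWrapper(TwoLayerModel, weights_path, model_kwargs={"hidden": 8})
outputs_match = np.allclose(np.asarray(wrapper(x)), np.asarray(trained(x)))
```

Passing the class rather than an instance is what makes this work: each call to `build` gets its own freshly constructed model, which is the role `clone_model` plays in the functional case.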

I took the code from this section of the “Integrating a Keras Model” example and modified it with the changes I described above. The modified code is here: test_tf_model_subclass.py (5.5 KB)
Feel free to download it and play with it. :smiley:
