In the NengoDL example you linked, the purpose of the `KerasWrapper` class is to build the Keras model within the `nengo.Network()` context (i.e., within the `with nengo.Network() as net` block). Because the example uses a Keras functional model, the `KerasWrapper` class needs to use the `clone_model` function to “re-build” the model when the `KerasWrapper.build` method is called.
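To make the role of `clone_model` concrete, here is a small standalone sketch that runs outside NengoDL, assuming TensorFlow 2.x and a made-up two-layer functional model. `tf.keras.models.clone_model` rebuilds the same architecture with brand-new layer objects and freshly initialized weights; it does not copy the trained weights, which is why the example loads the saved weights again after cloning.

```python
import numpy as np
import tensorflow as tf

# A small functional model (hypothetical architecture, for illustration only)
inp = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inp)
out = tf.keras.layers.Dense(2)(hidden)
model = tf.keras.Model(inputs=inp, outputs=out)

# clone_model creates new layers (and therefore new, re-initialized weights)
# from the original model's configuration
clone = tf.keras.models.clone_model(model)

# The clone only computes the same function once the weights are copied over
clone.set_weights(model.get_weights())

x = np.random.rand(3, 4).astype("float32")
same_output = np.allclose(np.asarray(model(x)), np.asarray(clone(x)))
print(same_output)
```

In the NengoDL example, the equivalent of the `set_weights` step is the `load_weights(model_weights)` call that follows `clone_model` inside `build`.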
With the functional model, you could technically re-define the entire Keras model inside the `build` method, like so:
```python
import tensorflow as tf


class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shapes):
        super().build(input_shapes)

        # <... copy the functional model definition here ...>
        # and assign the resulting model to self.model, e.g.
        # self.model = tf.keras.Model(inputs=inp, outputs=out)

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```
But the purpose of the example is to show how you would write the `KerasWrapper` class without having to do this (i.e., without duplicating code).
With a Keras model subclass, a new model (with its own layers and weights) is created every time you create an instance of the subclass. Thus, the `KerasWrapper` class can be modified to just make an instance of the model subclass, rather than using the `clone_model` function. Assuming you have defined your Keras model subclass as `MySequentialModel`, the modified `KerasWrapper` class would look like:
```python
class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()

    def build(self, input_shapes):
        super().build(input_shapes)

        # Create Keras model subclass instance
        self.model = MySequentialModel()

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```
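The post assumes `MySequentialModel` is already defined. Purely as a hypothetical illustration (the class name comes from the post, but the layers and sizes are made up), a minimal subclass could look like this. Because its layers are created in `__init__`, every call to `MySequentialModel()` yields a model with its own separate layers and weights, which is why `build` can simply instantiate the class instead of cloning.

```python
import numpy as np
import tensorflow as tf


class MySequentialModel(tf.keras.Model):
    """Hypothetical stand-in for the model subclass assumed in the post."""

    def __init__(self):
        super().__init__()
        # Layers are created here, so each new instance gets its own layers
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.dense_out = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense_out(self.hidden(inputs))


x = np.zeros((1, 4), dtype="float32")

model_a = MySequentialModel()
model_b = MySequentialModel()
model_a(x)  # the first call creates (builds) the weights
model_b(x)

# The two instances never share layers or variables
shares_layers = model_a.hidden is model_b.hidden
print(shares_layers)
```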
You can then make the `KerasWrapper` class more general by passing the Keras model subclass itself (the class, not an instance of it) to the `KerasWrapper.__init__()` method:
```python
class KerasWrapper(tf.keras.layers.Layer):
    def __init__(self, keras_model_class):
        super().__init__()
        self.model_class = keras_model_class

    def build(self, input_shapes):
        super().build(input_shapes)

        # Create Keras model subclass instance
        self.model = self.model_class()

        # load the weights we saved above
        self.model.load_weights(model_weights)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)
```
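As a standalone check of the generalized wrapper outside NengoDL, here is a minimal sketch under several assumptions, all hypothetical: the `MySequentialModel` architecture, a temporary `.weights.h5` file, and an extra `weights_path` argument that replaces the global `model_weights` used in the example. It also makes one dummy call on the new instance before `load_weights`, because loading HDF5-style weight files into a subclassed model generally requires its variables to exist already.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class MySequentialModel(tf.keras.Model):
    """Hypothetical model subclass (architecture made up for illustration)."""

    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.dense_out = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.dense_out(self.hidden(inputs))


class KerasWrapper(tf.keras.layers.Layer):
    # weights_path is a hypothetical argument replacing the global model_weights
    def __init__(self, keras_model_class, weights_path):
        super().__init__()
        self.model_class = keras_model_class
        self.weights_path = weights_path

    def build(self, input_shapes):
        super().build(input_shapes)
        # Create a fresh instance of the model subclass
        self.model = self.model_class()
        # Run a dummy batch through it so its variables exist before loading
        self.model(tf.zeros((1,) + tuple(input_shapes[1:])))
        self.model.load_weights(self.weights_path)

    def call(self, inputs):
        # apply the model to the inputs
        return self.model(inputs)


# Stand-in for the trained model whose weights were saved earlier
source = MySequentialModel()
x = np.random.rand(3, 4).astype("float32")
expected = np.asarray(source(x))  # this call also builds the model
weights_path = os.path.join(tempfile.mkdtemp(), "my_model.weights.h5")
source.save_weights(weights_path)

# The wrapper builds its own instance and loads the saved weights into it
wrapper = KerasWrapper(MySequentialModel, weights_path)
result = np.asarray(wrapper(x))
print(np.allclose(result, expected))
```

Passing the weights location into the constructor, rather than reading a global, lets the same wrapper serve different model subclasses and checkpoints.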
I took the code from this section of the “Integrating a Keras Model” example and modified it with the changes I described above. The modified code is here: test_tf_model_subclass.py
Feel free to download it and play with it.