While the Keras Functional and Sequential APIs offer convenient ways to construct many standard neural network architectures, they can become cumbersome or insufficient when dealing with highly customized or dynamic model behaviors. For situations demanding maximum flexibility and control over the model's forward pass logic, TensorFlow provides the capability to define models by subclassing the `tf.keras.Model` class.
This approach treats your model definition like any other Python class. You inherit from `tf.keras.Model` and implement the necessary methods to define your model's components and its computation logic. This imperative style gives you complete freedom to implement intricate architectures, conditional logic, or recursive structures directly within Python code.
### `__init__` and `call`
When you subclass `tf.keras.Model`, there are two fundamental methods you need to implement:

- `__init__(self, ...)`: The constructor. This is where you define all the layers and sub-modules your model will use. It's important to define layers as attributes of your model instance (e.g., `self.my_layer = tf.keras.layers.Dense(...)`) within `__init__`. This ensures that Keras can track the layers' variables (weights and biases) automatically. You can also define other constants or attributes needed for your model's logic here.
- `call(self, inputs, training=None, mask=None)`: The forward pass. This method contains the core logic of your model, defining how inputs are transformed into outputs. You use the layers defined in `__init__` here, calling them on the input tensors or intermediate tensors. The `inputs` argument receives the input data. The optional `training` argument is a boolean indicating whether the model is being run in training or inference mode. This is significant for layers like `Dropout` or `BatchNormalization` that behave differently in these two modes. It's good practice to include the `training` argument in your `call` signature and pass it along to any such layers used within the method.
Let's implement a simple Multi-Layer Perceptron (MLP) using the subclassing API to illustrate the concept.
```python
import tensorflow as tf

class SimpleMLP(tf.keras.Model):
    def __init__(self, num_units_l1, num_units_l2, num_classes, name="simple_mlp", **kwargs):
        super().__init__(name=name, **kwargs)
        # Define layers in the constructor
        self.dense_layer_1 = tf.keras.layers.Dense(num_units_l1, activation='relu')
        self.dropout_layer = tf.keras.layers.Dropout(0.5)  # Example layer needing 'training' flag
        self.dense_layer_2 = tf.keras.layers.Dense(num_units_l2, activation='relu')
        self.output_layer = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        # Define the forward pass logic
        x = self.dense_layer_1(inputs)
        # Pass the 'training' argument to layers that need it
        x = self.dropout_layer(x, training=training)
        x = self.dense_layer_2(x)
        outputs = self.output_layer(x)
        return outputs

# Instantiate the model
mlp_model = SimpleMLP(num_units_l1=128, num_units_l2=64, num_classes=10)

# Build the model (optional, happens automatically on first call)
# This step initializes the weights based on the input shape
mlp_model.build(input_shape=(None, 784))

# You can inspect the model
mlp_model.summary()

# Example call (forward pass)
# Create some dummy data (batch size 4, feature size 784)
dummy_input = tf.random.normal([4, 784])
output_tensor = mlp_model(dummy_input, training=False)  # Call during inference
print("Output shape:", output_tensor.shape)
```
In this example:

- The layers (`Dense`, `Dropout`) are defined and stored as instance attributes in `__init__`.
- The `call` method dictates the sequence of operations: input -> dense1 -> dropout -> dense2 -> output.
- The `training` argument is explicitly passed to the `Dropout` layer within `call`. Keras automatically provides the correct boolean value for `training` when you use `model.fit()`, `model.evaluate()`, or `model.predict()`.

The primary motivation for using the subclassing API is flexibility. It allows you to:
- Use native Python control flow: `if/else` statements, `for` loops, or even calls to external Python functions can appear within your `call` method. While `tf.function` (which Keras uses automatically) imposes certain constraints on Python code for graph conversion, it handles common control flow constructs effectively (see Chapter 1).
- Debug with standard Python tools: you can set breakpoints and inspect intermediate values inside the `call` method, especially during eager execution.

Choose subclassing when the Functional API feels restrictive, such as when implementing novel research ideas, models with dynamic behaviors, or architectures where the computation path is determined programmatically during the forward pass.
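As an illustrative sketch of such dynamic behavior (the `DynamicDepthMLP` name and structure are invented for this example), the following model uses a plain Python `for` loop and a conditional inside `call`, something awkward to express declaratively:

```python
import tensorflow as tf

class DynamicDepthMLP(tf.keras.Model):
    """Hypothetical model whose forward pass uses Python control flow."""
    def __init__(self, num_blocks, units, **kwargs):
        super().__init__(**kwargs)
        self.input_proj = tf.keras.layers.Dense(units, activation='relu')
        # A Python list of layers; Keras still tracks their variables
        self.blocks = [tf.keras.layers.Dense(units, activation='relu')
                       for _ in range(num_blocks)]
        self.classifier = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=None, skip_residuals=False):
        x = self.input_proj(inputs)
        # A Python for loop over sub-layers, unrolled when tf.function traces call()
        for block in self.blocks:
            if skip_residuals:   # Python conditional on a plain bool argument
                x = block(x)
            else:
                x = x + block(x)  # residual connection
        return self.classifier(x)

dynamic_model = DynamicDepthMLP(num_blocks=3, units=64)
out = dynamic_model(tf.random.normal([2, 32]), training=False)
print(out.shape)  # (2, 10)
```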
It's helpful to contrast the subclassing approach with the Keras Functional API:
| Feature | Functional API | Subclassing API |
|---|---|---|
| Definition | Declarative: define graph structure | Imperative: define forward pass logic |
| Flexibility | Good for static, DAG-like graphs | High for dynamic/complex logic |
| Architecture | Explicitly defined by layer connections | Implicitly defined by `call` method |
| Visualization | Easy model plotting (`tf.keras.utils.plot_model`) | Can be harder to visualize static structure |
| Serialization | Generally straightforward | SavedModel works well; some complex Python logic might need care |
| Debugging | Debug graph construction/execution | Debug Python code in `call` (often easier) |
The Functional API creates a static graph representation of your model upfront. This graph can be easily inspected, plotted, and reasoned about. The Subclassing API defines the forward pass imperatively through Python code executed in the `call` method. While `tf.function` compiles this into a graph for performance, the definition itself is more dynamic.
*Comparison of defining a model using the Functional API versus the Subclassing API. The Functional API explicitly defines the layer graph, while the Subclassing API defines the forward pass logic within the `call` method, utilizing layers defined in `__init__`.*
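For a concrete contrast, here is the same MLP expressed with the Functional API. This is a minimal sketch mirroring the `SimpleMLP` above; the graph of layer connections is declared up front rather than inside a `call` method:

```python
import tensorflow as tf

# The same architecture as SimpleMLP, declared as a static layer graph
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation='relu')(inputs)
x = tf.keras.layers.Dropout(0.5)(x)  # Keras routes the training flag automatically
x = tf.keras.layers.Dense(64, activation='relu')(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)

functional_mlp = tf.keras.Model(inputs=inputs, outputs=outputs, name="functional_mlp")
functional_mlp.summary()
```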
### The `training` Argument

Remember that the `training` argument in `call(self, inputs, training=None)` is significant. Keras automatically sets this to `True` during `model.fit()` and `False` during `model.evaluate()` or `model.predict()`. You must pass this argument to any layer within your `call` method that has different behavior during training and inference, such as `tf.keras.layers.Dropout` or `tf.keras.layers.BatchNormalization`. Forgetting to do this is a common source of errors, leading to models behaving unexpectedly during inference (e.g., dropout still being active).
```python
# Inside the call method of a subclassed model:
def call(self, inputs, training=None):
    x = self.conv_layer(inputs)
    x = self.batch_norm_layer(x, training=training)  # Correct: pass training flag
    x = tf.nn.relu(x)
    x = self.dropout_layer(x, training=training)  # Correct: pass training flag
    # ... rest of the forward pass
    return outputs
```
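To see the flag in action, you can call the `SimpleMLP` from earlier directly. This is a small sketch; exact values depend on random initialization, but the pattern holds:

```python
x = tf.random.normal([4, 784])

# With training=True, dropout randomly zeroes activations, so repeated
# calls on the same input generally produce different outputs.
train_out_1 = mlp_model(x, training=True)
train_out_2 = mlp_model(x, training=True)
print(tf.reduce_any(train_out_1 != train_out_2).numpy())  # typically True

# With training=False, dropout is a no-op and the forward pass is deterministic.
infer_out_1 = mlp_model(x, training=False)
infer_out_2 = mlp_model(x, training=False)
print(tf.reduce_all(infer_out_1 == infer_out_2).numpy())  # True
```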
Models created via subclassing can typically be saved and loaded just like Sequential or Functional models using `model.save()` and `tf.keras.models.load_model()`. This saves the architecture (by inspecting the `call` method traced by `tf.function`), weights, and optimizer state into the TensorFlow SavedModel format.
However, be mindful that the serialization relies on TensorFlow's ability to trace the `call` method into a graph. If your `call` method contains complex Python logic that cannot be easily traced or relies heavily on external Python state, saving and loading might require more careful handling or custom serialization logic. For most standard deep learning operations and control flow, `tf.function` and SavedModel work effectively.
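As a minimal sketch, assuming a TF 2.x setup where `model.save()` with a plain directory path writes the SavedModel format (the `simple_mlp_saved` path is just an illustrative name), the earlier `mlp_model` can be round-tripped like this:

```python
# Save architecture, weights, and optimizer state to a SavedModel directory
mlp_model.save("simple_mlp_saved")

# Reload; the traced call() graph restores the forward pass
restored_model = tf.keras.models.load_model("simple_mlp_saved")

# The restored model should reproduce the original outputs
restored_output = restored_model(dummy_input, training=False)
print(tf.reduce_max(tf.abs(restored_output - output_tensor)).numpy())  # ~0.0
```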
By mastering the subclassing API, you equip yourself with the ability to implement virtually any model architecture or behavior within TensorFlow, moving beyond predefined patterns to create truly custom machine learning solutions.