When constructing neural networks in TensorFlow, you're likely familiar with the tf.keras.Sequential
model for simple stacks of layers and the Keras Functional API for more complex architectures, such as those with multiple inputs, multiple outputs, or shared layers. PyTorch offers analogous ways to define model architectures, primarily through subclassing torch.nn.Module
, which provides maximum flexibility, and torch.nn.Sequential
for simpler cases. This section will guide you through these PyTorch approaches, drawing parallels to their Keras counterparts.
Keras Sequential and PyTorch nn.Sequential

For straightforward models where layers are arranged in a linear sequence, Keras provides the Sequential API.

The Keras Sequential API: A Quick Refresher

In Keras, you can define a simple model like this:
import tensorflow as tf

keras_sequential_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

keras_sequential_model.summary()
This creates a model where data flows through Flatten, then Dense(128), and finally Dense(10).

PyTorch's nn.Sequential Container

PyTorch offers torch.nn.Sequential as a convenient way to create similar linear stacks of layers. It's a container that passes data sequentially through the modules it holds.
import torch
import torch.nn as nn

pytorch_sequential_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
    # Softmax is often applied as part of the loss function (e.g., nn.CrossEntropyLoss)
    # or explicitly in the forward pass if needed for inference.
)

print(pytorch_sequential_model)
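As a quick sanity check, you can pass a dummy batch through the container. This is a minimal sketch using the pytorch_sequential_model defined above; the input tensor is a random placeholder.

dummy_batch = torch.randn(64, 1, 28, 28)        # (batch, channels, height, width)
logits = pytorch_sequential_model(dummy_batch)  # nn.Flatten() collapses everything after the batch dim
print(logits.shape)                             # Expected: torch.Size([64, 10])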
Notice a few things:

- nn.Flatten() is similar to tf.keras.layers.Flatten.
- nn.Linear(in_features, out_features) is PyTorch's equivalent of tf.keras.layers.Dense. You need to specify the input features for the first linear layer.
- Activation functions such as nn.ReLU() are often added as separate layers within nn.Sequential.
- nn.LogSoftmax(dim=1) or nn.Softmax(dim=1) could be added as the final layer if you need the direct output probabilities. However, it's common practice to omit the final softmax if you're using nn.CrossEntropyLoss, as it combines LogSoftmax and NLLLoss; a brief sketch of this pattern follows this list.
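As a minimal sketch of that pattern (using the pytorch_sequential_model defined above; the labels are random placeholders), the raw logits go straight into nn.CrossEntropyLoss:

criterion = nn.CrossEntropyLoss()                              # internally combines LogSoftmax and NLLLoss
logits = pytorch_sequential_model(torch.randn(16, 1, 28, 28))  # raw scores, no softmax applied
targets = torch.randint(0, 10, (16,))                          # placeholder integer class labels
loss = criterion(logits, targets)
print(loss.item())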
The nn.Sequential container is excellent for prototyping or when your model architecture is a simple feed-forward pipeline. However, for models with more intricate data flows, shared layers, or multiple input/output paths, you'll turn to PyTorch's more versatile approach: subclassing nn.Module.
Building Complex Models by Subclassing torch.nn.Module

While Keras uses its Functional API to build complex models like directed acyclic graphs (DAGs), PyTorch's primary method is to create a custom class that inherits from torch.nn.Module. This object-oriented approach is highly flexible and Pythonic, giving you full control over how your model processes data.
The Keras Functional API allows you to define complex models by connecting layers. For instance, a model with a skip connection might look like this:
# Keras Functional API Example (Illustrative)
input_tensor = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(input_tensor)
x = tf.keras.layers.MaxPooling2D((2,2))(x)
residual = x # Store for skip connection
x = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(x)
x = tf.keras.layers.Conv2D(32, (1,1), activation='relu', padding='same')(x) # Reduce channels
x = tf.keras.layers.Add()([x, residual]) # Add skip connection
x = tf.keras.layers.Flatten()(x)
output_tensor = tf.keras.layers.Dense(10, activation='softmax')(x)
keras_functional_model = tf.keras.Model(inputs=input_tensor, outputs=output_tensor)
# keras_functional_model.summary()
This API is powerful for models that aren't just a simple stack.
Defining Models by Subclassing torch.nn.Module

In PyTorch, you achieve this level of flexibility by defining your model as a class that inherits from nn.Module. This approach involves two main parts:

The __init__(self) method:

- Call super().__init__() first.
- Define the layers your model will use as attributes of the class (e.g., self.conv1 = nn.Conv2d(...)).
- These layers are themselves instances of nn.Module, meaning PyTorch can track their parameters, move them between devices (CPU/GPU), etc.; a short check after the example below demonstrates this.

The forward(self, input_data, ...) method:

- This defines the computation performed on the input, using the layers declared in __init__.
- You call the layers like functions (e.g., x = self.conv1(input_data)) and can use any Python logic to direct the data flow. This is where PyTorch's "define-by-run" nature comes into play; the computation graph is built dynamically as the forward method executes.
- The arguments you pass to forward are your model's inputs, and whatever it returns are your model's outputs.

Let's translate a similar idea of a network with the potential for more complex routing (though we'll keep it simple for illustration) into a PyTorch nn.Module:
import torch
import torch.nn as nn
import torch.nn.functional as F  # Common for stateless operations like activation functions

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        # (Batch, 1, 28, 28) -> (Batch, 16, 28, 28)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # (Batch, 16, 28, 28) -> (Batch, 16, 14, 14)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # (Batch, 16, 14, 14) -> (Batch, 32, 14, 14)
        # After another pooling: (Batch, 32, 7, 7)

        # Placeholder for a layer that might be used in a skip connection path
        self.conv_skip_adjust = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1)

        # Flattened size: 32 channels * 7x7 feature map
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # x shape: (batch_size, 1, 28, 28)
        # First conv block
        x1_conv = F.relu(self.conv1(x))
        x1_pool = self.pool(x1_conv)  # (Batch, 16, 14, 14)

        # Second conv block
        x2_conv = F.relu(self.conv2(x1_pool))
        x2_pool = self.pool(x2_conv)  # (Batch, 32, 7, 7)

        # Example of a simple "skip" or parallel path.
        # For a true skip, dimensions must match for addition/concatenation.
        # Here, we'll just process it and show how it could be combined.
        # Let's imagine we wanted to incorporate x1_pool (16 channels) with x2_pool (32 channels).
        # We might need to adjust x1_pool's channels.
        skip_path = self.conv_skip_adjust(x1_pool)  # (Batch, 32, 14, 14)
        skip_path_pooled = self.pool(skip_path)     # (Batch, 32, 7, 7)

        # For demonstration, let's assume we want to add x2_pool and skip_path_pooled.
        # This requires them to have the same shape.
        # combined = x2_pool + skip_path_pooled  # This is where a skip connection would happen

        # For this example, we'll proceed with x2_pool as the main path.
        # Flatten the output for the fully connected layers
        x_flat = x2_pool.view(-1, 32 * 7 * 7)  # .view is like reshape, -1 infers batch size
        x_fc1 = F.relu(self.fc1(x_flat))

        # No softmax here if using nn.CrossEntropyLoss
        output = self.fc2(x_fc1)
        return output

# Instantiate the model
pytorch_custom_model = SimpleCNN(num_classes=10)
print(pytorch_custom_model)

# Example of a dummy input tensor
dummy_input = torch.randn(64, 1, 28, 28)  # (batch_size, channels, height, width)
output = pytorch_custom_model(dummy_input)
print("Output shape:", output.shape)  # Expected: (64, 10)
In this SimpleCNN class:

- The nn.Conv2d, nn.MaxPool2d, and nn.Linear layers are defined in __init__.
- The forward method dictates the data flow. We use F.relu from torch.nn.functional for the ReLU activation, which is a common practice for stateless operations. Alternatively, nn.ReLU() could be defined in __init__ and called in forward.
- The .view(-1, 32 * 7 * 7) call flattens the tensor before the fully connected layers, similar to tf.keras.layers.Flatten(). The -1 tells PyTorch to infer the batch size.
- The conv_skip_adjust and skip_path lines illustrate where you might define operations for a more complex path, like a skip connection. For a true additive skip connection, x2_pool and skip_path_pooled would need to have identical shapes, which they do here: both are (Batch, 32, 7, 7). The sketch after this list shows the addition enabled.
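As a small sketch (the subclass name is illustrative, not from the original), here is a variant of SimpleCNN whose forward pass actually performs that addition:

class SimpleCNNWithSkip(SimpleCNN):
    # Reuses the layers defined in SimpleCNN.__init__; only the data flow changes.
    def forward(self, x):
        x1_pool = self.pool(F.relu(self.conv1(x)))        # (Batch, 16, 14, 14)
        x2_pool = self.pool(F.relu(self.conv2(x1_pool)))  # (Batch, 32, 7, 7)
        skip = self.pool(self.conv_skip_adjust(x1_pool))  # (Batch, 32, 7, 7)
        combined = x2_pool + skip                         # shapes match, so the addition is valid
        x_flat = combined.view(-1, 32 * 7 * 7)
        return self.fc2(F.relu(self.fc1(x_flat)))

skip_model = SimpleCNNWithSkip(num_classes=10)
print(skip_model(torch.randn(8, 1, 28, 28)).shape)  # Expected: torch.Size([8, 10])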
Comparing Keras APIs with PyTorch nn.Sequential and nn.Module

| Feature | Keras Sequential | Keras Functional API | PyTorch nn.Sequential | PyTorch nn.Module Subclassing |
|---|---|---|---|---|
| Primary Use Case | Simple linear stacks | Models with branches, shared layers | Simple linear stacks | All model types, especially complex |
| Flexibility | Low | High | Low | Very High |
| Model Definition | List of layers | Graph of interconnected layers | Sequence of modules | Python class with __init__ & forward |
| Dynamic Behavior | No (graph defined statically) | No (graph defined statically) | Limited (static sequence) | Yes (Python logic in forward) |
| Debugging | Relatively simple | Can be complex for large graphs | Relatively simple | Easier with Python debuggers |
| Multiple Inputs/Outputs | No | Yes | No | Yes (via forward signature/return) |
| Shared Layers | No (not naturally) | Yes | No (not naturally) | Yes (instantiate once, use multiple times) |
nn.Module Subclassing is the PyTorch Standard

While nn.Sequential is useful for simple cases, subclassing nn.Module is the idiomatic and most powerful way to build models in PyTorch due to:

- Dynamic control flow: The forward method is pure Python. You can use if statements, for loops, call other methods of your class, or integrate any Python library to define the computation. This allows for models whose structure can change based on the input data or other conditions during runtime (see the sketch after this list).
- Clear structure: Defining layers in __init__ and the data flow logic in forward provides a clear and organized structure.
- Easier debugging: Because the forward pass is Python code, you can use standard Python debugging tools (like pdb or print statements) to inspect tensors and understand the model's behavior at any point.
- Multiple inputs and outputs: Define your forward method to accept multiple arguments and/or return multiple tensors, as in the example below:
class MultiInputOutputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1_input1 = nn.Linear(10, 20)
        self.layer1_input2 = nn.Linear(5, 20)
        self.shared_layer = nn.Linear(20, 30)
        self.output_branch1 = nn.Linear(30, 1)
        self.output_branch2 = nn.Linear(30, 1)

    def forward(self, input1, input2):
        x1 = F.relu(self.layer1_input1(input1))
        x2 = F.relu(self.layer1_input2(input2))
        # Example: Concatenate or add processed inputs
        merged = x1 + x2  # Assuming shapes match for addition
        shared_out = F.relu(self.shared_layer(merged))
        out1 = self.output_branch1(shared_out)
        out2 = self.output_branch2(shared_out)
        return out1, out2
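A brief usage sketch (the tensors below are random placeholders sized to match the layer definitions above):

multi_model = MultiInputOutputModel()
input_a = torch.randn(4, 10)   # matches layer1_input1's in_features
input_b = torch.randn(4, 5)    # matches layer1_input2's in_features
out1, out2 = multi_model(input_a, input_b)
print(out1.shape, out2.shape)  # Expected: torch.Size([4, 1]) torch.Size([4, 1])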
- Shared layers: Instantiate a layer once in __init__ (e.g., self.shared_conv = nn.Conv2d(...)) and call it multiple times in forward on different inputs (e.g., out_a = self.shared_conv(input_a), out_b = self.shared_conv(input_b)).
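To make the dynamic control flow point concrete, here is a minimal sketch (the class, threshold, and layer sizes are illustrative, not from the original) of a forward pass whose routing is decided at runtime:

class DynamicRoutingModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.small_path = nn.Linear(10, 10)
        self.large_path = nn.Linear(10, 10)
        self.head = nn.Linear(10, 1)

    def forward(self, x):
        # Data-dependent branching: which layer runs depends on the input values.
        if x.abs().mean() > 1.0:
            x = F.relu(self.large_path(x))
        else:
            x = F.relu(self.small_path(x))
        # Plain Python loops work too; the graph is built as this code executes.
        for _ in range(2):
            x = F.relu(self.small_path(x))
        return self.head(x)

print(DynamicRoutingModel()(torch.randn(4, 10)).shape)  # Expected: torch.Size([4, 1])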
Visualizing the Structure: Keras Functional API vs. PyTorch nn.Module

Let's consider a model with a simple skip connection to illustrate how each approach defines the structure.
Keras Functional API: You define inputs, then layers, and explicitly connect them. A skip connection involves taking an earlier tensor and adding or concatenating it with a later tensor.

PyTorch nn.Module subclass: The forward method directly implements this flow.
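As a rough sketch of that translation (the class name and the assumption of a single input channel are illustrative, mirroring the Keras example above), the skip connection becomes a few lines inside forward:

class SkipConnectionNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=1)   # 1x1 conv to reduce channels back to 32
        self.fc = nn.Linear(32 * 14 * 14, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # (Batch, 32, 14, 14)
        residual = x                          # store for the skip connection
        x = F.relu(self.conv2(x))             # (Batch, 64, 14, 14)
        x = F.relu(self.conv3(x))             # (Batch, 32, 14, 14)
        x = x + residual                      # additive skip connection
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)                     # raw logits; softmax is left to the loss function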
A diagram illustrating how a skip connection is expressed in the Keras Functional API versus how the logic flows within a PyTorch nn.Module's forward method.
In PyTorch, there isn't usually an explicit Input layer like in the Keras Functional API. The input shape is implicitly defined by the data you pass to the model during the first forward pass, or explicitly set in the first layer's definition (e.g., in_channels for nn.Conv2d or in_features for nn.Linear).
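As a minimal illustration of this point, a layer fixes only the feature dimensions; the batch size is simply whatever you pass in:

layer = nn.Linear(in_features=784, out_features=128)  # input size declared on the layer, no Input object
print(layer(torch.randn(32, 784)).shape)  # torch.Size([32, 128])
print(layer(torch.randn(1, 784)).shape)   # torch.Size([1, 128])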
The transition from Keras's model-building APIs to PyTorch's nn.Module
subclassing is a shift towards a more programmatic and explicit way of defining your network's forward pass. This Python-centric approach offers a great deal of power and flexibility, especially for research and models with unconventional architectures. While nn.Sequential
provides a Keras-like convenience for simpler models, mastering nn.Module
subclassing is fundamental to fully leveraging PyTorch.