When constructing neural networks in TensorFlow, you're likely familiar with the tf.keras.Sequential model for simple stacks of layers and the Keras Functional API for more complex architectures, such as those with multiple inputs, multiple outputs, or shared layers. PyTorch offers analogous ways to define model architectures: subclassing torch.nn.Module, which provides maximum flexibility, and torch.nn.Sequential for simpler cases. This section presents both approaches, drawing parallels to their Keras counterparts.
Keras Sequential and PyTorch nn.Sequential

For straightforward models where layers are arranged in a linear sequence, Keras provides the Sequential API.
Keras Sequential API: A Quick Refresher

In Keras, you can define a simple model like this:
import tensorflow as tf
keras_sequential_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
keras_sequential_model.summary()
This creates a model where data flows through Flatten, then Dense(128), and finally Dense(10).
The nn.Sequential Container

PyTorch offers torch.nn.Sequential as a convenient way to create similar linear stacks of layers. It's a container that passes data sequentially through the modules it holds.
import torch
import torch.nn as nn
pytorch_sequential_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
    # Softmax is often applied as part of the loss function (e.g., nn.CrossEntropyLoss)
    # or explicitly in the forward pass if needed for inference.
)
print(pytorch_sequential_model)
Notice a few things:
- nn.Flatten() is similar to tf.keras.layers.Flatten.
- nn.Linear(in_features, out_features) is PyTorch's equivalent of tf.keras.layers.Dense. You need to specify the input features for the first linear layer.
- Activation functions such as nn.ReLU() are often added as separate layers within nn.Sequential.
- nn.LogSoftmax(dim=1) or nn.Softmax(dim=1) could be added as the final layer if you need the direct output probabilities. However, it's common practice to omit the final softmax if you're using nn.CrossEntropyLoss, as it combines LogSoftmax and NLLLoss.

The nn.Sequential container is excellent for prototyping or when your model architecture is a simple feed-forward pipeline. However, for models with more intricate data flows, shared layers, or multiple input/output paths, you'll turn to PyTorch's more versatile approach: subclassing nn.Module.
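To make the softmax point concrete, here is a minimal sketch (using made-up dummy tensors) of how the nn.Sequential model defined above would typically be trained and then queried for probabilities: nn.CrossEntropyLoss consumes raw logits during training, and softmax is applied only at inference time.

import torch.nn.functional as F

# Dummy data, for illustration only: a batch of 64 28x28 "images" and integer class labels.
images = torch.randn(64, 28, 28)
labels = torch.randint(0, 10, (64,))

criterion = nn.CrossEntropyLoss()           # expects raw logits; applies LogSoftmax internally
logits = pytorch_sequential_model(images)   # shape (64, 10), no softmax in the model
loss = criterion(logits, labels)
loss.backward()

# For inference, apply softmax explicitly only when you need probabilities.
with torch.no_grad():
    probabilities = F.softmax(pytorch_sequential_model(images), dim=1)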
Keras Functional API and PyTorch torch.nn.Module

While Keras uses its Functional API to build complex models like directed acyclic graphs (DAGs), PyTorch's primary method is to create a custom class that inherits from torch.nn.Module. This object-oriented approach is highly flexible and Pythonic, giving you full control over how your model processes data.
The Keras Functional API allows you to define complex models by connecting layers. For instance, a model with a skip connection might look like this:
# Keras Functional API Example (Illustrative)
input_tensor = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(input_tensor)
x = tf.keras.layers.MaxPooling2D((2,2))(x)
residual = x # Store for skip connection
x = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(x)
x = tf.keras.layers.Conv2D(32, (1,1), activation='relu', padding='same')(x) # Reduce channels
x = tf.keras.layers.Add()([x, residual]) # Add skip connection
x = tf.keras.layers.Flatten()(x)
output_tensor = tf.keras.layers.Dense(10, activation='softmax')(x)
keras_functional_model = tf.keras.Model(inputs=input_tensor, outputs=output_tensor)
# keras_functional_model.summary()
This API is powerful for models that aren't just a simple stack.
Subclassing torch.nn.Module

In PyTorch, you achieve this level of flexibility by defining your model as a class that inherits from nn.Module. This approach involves two main parts:
1. The __init__(self) method:
   - Call super().__init__() first.
   - Define the layers your model will use as attributes (e.g., self.conv1 = nn.Conv2d(...)).
   - These layers are themselves instances of nn.Module, meaning PyTorch can track their parameters, move them between devices (CPU/GPU), etc.
2. The forward(self, input_data, ...) method:
   - Defines how data flows through the layers created in __init__. You call the layers like functions (e.g., x = self.conv1(input_data)) and can use any Python logic to direct the data flow. This is where PyTorch's "define-by-run" nature comes into play; the computation graph is built dynamically as the forward method executes.
   - The arguments to forward are your model's inputs, and what it returns are your model's outputs.

Let's translate a similar idea of a network with a potential for more complex routing (though we'll keep it simple for illustration) into a PyTorch nn.Module:
import torch
import torch.nn as nn
import torch.nn.functional as F # Common for stateless operations like activation functions
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        # (Batch, 1, 28, 28) -> (Batch, 16, 28, 28)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # (Batch, 16, 28, 28) -> (Batch, 16, 14, 14)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # (Batch, 16, 14, 14) -> (Batch, 32, 14, 14)
        # After another pooling: (Batch, 32, 7, 7)
        # Placeholder for a layer that might be used in a skip connection path
        self.conv_skip_adjust = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1)
        # Flattened size: 32 channels * 7x7 feature map
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        # x shape: (batch_size, 1, 28, 28)
        # First conv block
        x1_conv = F.relu(self.conv1(x))
        x1_pool = self.pool(x1_conv)  # (Batch, 16, 14, 14)
        # Second conv block
        x2_conv = F.relu(self.conv2(x1_pool))
        x2_pool = self.pool(x2_conv)  # (Batch, 32, 7, 7)
        # Example of a simple "skip" or parallel path.
        # For a true skip, dimensions must match for addition/concatenation.
        # Here, we'll just process it and show how it could be combined.
        # Let's imagine we wanted to incorporate x1_pool (16 channels) with x2_pool (32 channels);
        # we might need to adjust x1_pool's channels.
        skip_path = self.conv_skip_adjust(x1_pool)  # (Batch, 32, 14, 14)
        skip_path_pooled = self.pool(skip_path)     # (Batch, 32, 7, 7)
        # To add x2_pool and skip_path_pooled, they must have the same shape.
        # combined = x2_pool + skip_path_pooled  # This is where a skip connection would happen
        # For this example, we'll proceed with x2_pool as the main path.
        # Flatten the output for the fully connected layers
        x_flat = x2_pool.view(-1, 32 * 7 * 7)  # .view is like reshape, -1 infers batch size
        x_fc1 = F.relu(self.fc1(x_flat))
        # No softmax here if using nn.CrossEntropyLoss
        output = self.fc2(x_fc1)
        return output
# Instantiate the model
pytorch_custom_model = SimpleCNN(num_classes=10)
print(pytorch_custom_model)
# Example of a dummy input tensor
dummy_input = torch.randn(64, 1, 28, 28) # (batch_size, channels, height, width)
output = pytorch_custom_model(dummy_input)
print("Output shape:", output.shape) # Expected: (64, 10)
In this SimpleCNN class:
- Layers such as nn.Conv2d, nn.MaxPool2d, and nn.Linear are defined in __init__.
- The forward method dictates the data flow. We use F.relu from torch.nn.functional for the ReLU activation, which is a common practice for stateless operations. Alternatively, nn.ReLU() could be defined in __init__ and called in forward.
- The .view(-1, 32 * 7 * 7) call flattens the tensor before the fully connected layers, similar to tf.keras.layers.Flatten(). The -1 tells PyTorch to infer the batch size.
- The conv_skip_adjust and skip_path lines illustrate where you might define operations for a more complex path, like a skip connection. For a true additive skip connection, x2_pool and skip_path_pooled would need to have identical shapes, which they do here; a short sketch of completing that addition follows.
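The sketch below shows one way the addition could be wired up. SimpleCNNWithSkip is a hypothetical subclass introduced purely for illustration; it reuses the layers already defined in SimpleCNN's __init__ and changes only the forward logic.

class SimpleCNNWithSkip(SimpleCNN):
    # Hypothetical variant whose forward pass actually performs the additive skip connection.
    def forward(self, x):
        x1_pool = self.pool(F.relu(self.conv1(x)))         # (Batch, 16, 14, 14)
        x2_pool = self.pool(F.relu(self.conv2(x1_pool)))   # (Batch, 32, 7, 7)
        skip = self.pool(self.conv_skip_adjust(x1_pool))   # 1x1 conv + pool -> (Batch, 32, 7, 7)
        combined = x2_pool + skip                          # shapes match, so element-wise addition works
        x_flat = combined.view(-1, 32 * 7 * 7)
        return self.fc2(F.relu(self.fc1(x_flat)))

skip_model = SimpleCNNWithSkip(num_classes=10)
print(skip_model(torch.randn(4, 1, 28, 28)).shape)  # Expected: torch.Size([4, 10])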
Comparing Keras APIs with PyTorch nn.Sequential and nn.Module

| Feature | Keras Sequential | Keras Functional API | PyTorch nn.Sequential | PyTorch nn.Module Subclassing |
|---|---|---|---|---|
| Primary Use Case | Simple linear stacks | Models with branches, shared layers | Simple linear stacks | All model types, especially complex |
| Flexibility | Low | High | Low | Very High |
| Model Definition | List of layers | Graph of interconnected layers | Sequence of modules | Python class with __init__ & forward |
| Dynamic Behavior | No (graph defined statically) | No (graph defined statically) | Limited (static sequence) | Yes (Python logic in forward) |
| Debugging | Relatively simple | Can be complex for large graphs | Relatively simple | Easier with Python debuggers |
| Multiple Inputs/Outputs | No | Yes | No | Yes (via forward signature/return) |
| Shared Layers | No (not naturally) | Yes | No (not naturally) | Yes (instantiate once, use multiple times) |
nn.Module Subclassing is the PyTorch Standard

While nn.Sequential is useful for simple cases, subclassing nn.Module is the idiomatic and most powerful way to build models in PyTorch due to:
- The forward method is pure Python. You can use if statements, for loops, call other methods of your class, or integrate any Python library to define the computation. This allows for models whose structure can change based on the input data or other conditions during runtime.
- Defining the layers in __init__ and the data flow logic in forward provides a clear and organized structure.
- Because the forward pass is Python code, you can use standard Python debugging tools (like pdb or print statements) to inspect tensors and understand the model's behavior at any point.
- Multiple inputs and outputs are straightforward: define your forward method to accept multiple arguments and/or return multiple tensors, as in the example below.
class MultiInputOutputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1_input1 = nn.Linear(10, 20)
        self.layer1_input2 = nn.Linear(5, 20)
        self.shared_layer = nn.Linear(20, 30)
        self.output_branch1 = nn.Linear(30, 1)
        self.output_branch2 = nn.Linear(30, 1)

    def forward(self, input1, input2):
        x1 = F.relu(self.layer1_input1(input1))
        x2 = F.relu(self.layer1_input2(input2))
        # Example: Concatenate or add processed inputs
        merged = x1 + x2  # Assuming shapes match for addition
        shared_out = F.relu(self.shared_layer(merged))
        out1 = self.output_branch1(shared_out)
        out2 = self.output_branch2(shared_out)
        return out1, out2
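Calling such a model is straightforward: pass both inputs and unpack both outputs. A quick usage sketch with made-up batch data matching the layer sizes above:

model = MultiInputOutputModel()
input1 = torch.randn(32, 10)   # first input: 10 features per sample
input2 = torch.randn(32, 5)    # second input: 5 features per sample
out1, out2 = model(input1, input2)
print(out1.shape, out2.shape)  # Expected: torch.Size([32, 1]) torch.Size([32, 1])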
- Shared layers come naturally: define a layer once in __init__ (e.g., self.shared_conv = nn.Conv2d(...)) and call it multiple times in forward on different inputs (e.g., out_a = self.shared_conv(input_a), out_b = self.shared_conv(input_b)).

Expressing a Skip Connection: Keras Functional API vs. PyTorch nn.Module

Let's consider a model with a simple skip connection to illustrate the structural definition.
Keras Functional API: You define inputs, then layers, and explicitly connect them. A skip connection involves taking an earlier tensor and adding or concatenating it with a later tensor.
PyTorch nn.Module subclass: The forward method directly implements this flow.
[Diagram: how a skip connection is expressed in the Keras Functional API versus how the logic flows within a PyTorch nn.Module's forward method.]
In PyTorch, there isn't usually an explicit Input layer like in Keras Functional API. The input shape is implicitly defined by the data you pass to the model during the first forward pass, or explicitly set in the first layer's definition (e.g., in_channels for nn.Conv2d or in_features for nn.Linear).
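If you prefer Keras-like shape inference, PyTorch also offers lazy modules such as nn.LazyLinear, which leave in_features unspecified and infer it from the first batch they see. A minimal sketch (layer sizes chosen arbitrarily):

lazy_model = nn.Sequential(
    nn.Flatten(),
    nn.LazyLinear(128),   # in_features not given; inferred on the first forward pass
    nn.ReLU(),
    nn.Linear(128, 10)
)

_ = lazy_model(torch.randn(2, 28, 28))  # first call materializes LazyLinear as Linear(784, 128)
print(lazy_model)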
The transition from Keras's model-building APIs to PyTorch's nn.Module subclassing is a shift towards a more programmatic and explicit way of defining your network's forward pass. This Python-centric approach offers a great deal of power and flexibility, especially for research and models with unconventional architectures. While nn.Sequential provides a Keras-like convenience for simpler models, mastering nn.Module subclassing is fundamental to fully leveraging PyTorch.