Following the self-attention mechanism within each Transformer block, we introduce a component known as the Position-wise Feed-Forward Network (FFN). This network is applied independently and identically at each position in the sequence. While the self-attention layers allow tokens to interact with each other, the FFN processes the representation for each token separately, providing additional non-linear transformation capabilities to the model.
Think of it as adding further computational depth after the context mixing performed by the attention layer. It helps the model learn more complex functions of the features derived from the attention output for each position.
The FFN is typically a simple two-layer fully connected network. For an input representation $x$ at a specific position, the complete sub-layer (including the residual connection, which is handled in the surrounding block structure) computes:

$$\text{Linear}_2(\text{Activation}(\text{Linear}_1(x))) + x$$

Setting the residual connection aside, the core FFN consists of:

- a first linear layer that expands the representation from dimension $d_{\text{model}}$ to an inner dimension $d_{\text{ff}}$,
- a non-linear activation, ReLU in the original Transformer,
- a second linear layer that projects from $d_{\text{ff}}$ back to $d_{\text{model}}$.

The formula for the FFN operation itself (before the residual connection) is:

$$\text{FFN}(x) = \max(0,\, xW_1 + b_1)\,W_2 + b_2$$

where:

- $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$ and $b_1 \in \mathbb{R}^{d_{\text{ff}}}$ are the weight matrix and bias of the first linear layer,
- $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$ and $b_2 \in \mathbb{R}^{d_{\text{model}}}$ are the weight matrix and bias of the second linear layer,
- $\max(0, \cdot)$ is the element-wise ReLU activation.
The inner dimension $d_{\text{ff}}$ is usually larger than $d_{\text{model}}$. A common choice is $d_{\text{ff}} = 4 \times d_{\text{model}}$, as used in the original "Attention Is All You Need" paper. This expansion allows the model to potentially learn richer representations before projecting back to the original model dimension.
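Before packaging this into a module, it can help to see the raw formula in action. The following minimal sketch uses assumed toy dimensions and randomly initialized weights standing in for learned parameters; it applies the two matrix multiplications and the ReLU directly to a small batch of positions:

import torch
import torch.nn.functional as F

# Assumed toy dimensions for illustration only.
d_model, d_ff = 512, 2048          # d_ff = 4 * d_model, as in the original paper
x = torch.randn(10, d_model)       # 10 positions, each a d_model-dimensional vector

# Parameters of the two linear maps (randomly initialized stand-ins).
W1, b1 = torch.randn(d_model, d_ff), torch.zeros(d_ff)
W2, b2 = torch.randn(d_ff, d_model), torch.zeros(d_model)

# FFN(x) = max(0, x W1 + b1) W2 + b2, applied to every position at once.
hidden = F.relu(x @ W1 + b1)       # -> (10, d_ff): expanded inner representation
out = hidden @ W2 + b2             # -> (10, d_model): projected back to the model dimension
print(out.shape)                   # torch.Size([10, 512])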
Let's implement this FFN component as a PyTorch nn.Module. We will include dropout, which is often applied after the second linear layer within the FFN or as part of the residual connection step.
import torch
import torch.nn as nn
class PositionWiseFeedForward(nn.Module):
    """Implements the Position-wise Feed-Forward Network (FFN) module."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        """
        Initializes the PositionWiseFeedForward module.

        Args:
            d_model (int): The dimensionality of the input and output features.
            d_ff (int): The dimensionality of the inner layer.
            dropout (float): The dropout probability. Defaults to 0.1.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the FFN module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """
        # Apply the first linear layer, then activation,
        # then the second linear layer, then dropout.
        # x shape: (batch_size, seq_len, d_model)
        x = self.linear1(x)     # -> (batch_size, seq_len, d_ff)
        x = self.activation(x)  # -> (batch_size, seq_len, d_ff)
        # Dropout can be placed after the activation or after the second
        # linear layer; we place it after the second linear layer here,
        # consistent with some common implementations.
        x = self.linear2(x)     # -> (batch_size, seq_len, d_model)
        x = self.dropout(x)     # -> (batch_size, seq_len, d_model)
        return x
Let's test this module with some sample dimensions:
# Example usage:
d_model = 512 # Model dimension
d_ff = 2048 # Inner dimension (often 4 * d_model)
dropout_rate = 0.1
batch_size = 4
seq_len = 10
# Create a sample input tensor
input_tensor = torch.randn(batch_size, seq_len, d_model)
# Instantiate the FFN layer
ffn_layer = PositionWiseFeedForward(d_model, d_ff, dropout_rate)
# Pass the input through the FFN layer
output_tensor = ffn_layer(input_tensor)
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output_tensor.shape}")
# Verify output dimension matches d_model
assert output_tensor.shape == (batch_size, seq_len, d_model)
This code defines the PositionWiseFeedForward class. The __init__ method sets up the two linear layers (self.linear1, self.linear2), the ReLU activation (self.activation), and the dropout layer (self.dropout). The forward method defines the computation flow: input x goes through the first linear layer, then ReLU activation, then the second linear layer, and finally dropout.
Notice that the operations linear1, activation, and linear2 are applied independently to the representation at each sequence position. The layers share weights across positions within a single forward pass, but the calculation for position i does not directly depend on the calculation for position j within this FFN module (unlike the attention mechanism).
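A quick way to verify this position-wise behavior with the objects from the example above is to permute the sequence positions of the input and check that the output is permuted in exactly the same way. The sketch below assumes ffn_layer, input_tensor, and seq_len are still in scope and switches the module to eval() so dropout does not interfere with the comparison:

# Position-wise check: permuting positions before the FFN gives the same
# result as permuting them after, because each position is processed
# independently with shared weights.
ffn_layer.eval()  # disable dropout so both passes are deterministic
perm = torch.randperm(seq_len)
with torch.no_grad():
    out_then_permute = ffn_layer(input_tensor)[:, perm, :]
    permute_then_out = ffn_layer(input_tensor[:, perm, :])
print(torch.allclose(out_then_permute, permute_then_out, atol=1e-6))  # True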
This FFN module is a fundamental building block that we will integrate into the larger Encoder and Decoder layers in the subsequent sections. Its role is to provide non-linear processing capacity applied uniformly across all sequence positions after the context aggregation performed by the multi-head attention sub-layer.
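As a preview of that integration, here is one possible way to wrap the FFN in the residual connection and layer normalization discussed earlier. The FFNSubLayer name and the post-normalization arrangement are illustrative choices for this sketch, not the final encoder layer we will build in the next sections:

class FFNSubLayer(nn.Module):
    """Illustrative post-norm sub-layer: LayerNorm(x + FFN(x))."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the FFN, then layer normalization.
        return self.norm(x + self.ffn(x))

sub_layer = FFNSubLayer(d_model, d_ff, dropout_rate)
print(sub_layer(input_tensor).shape)  # torch.Size([4, 10, 512])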