Now that we have implementations for the core Multi-Head Attention and Position-wise Feed-Forward Network components, we can assemble them into the standard building blocks of the Transformer: the Encoder Layer and the Decoder Layer. These layers incorporate not only the attention and feed-forward sub-networks but also residual connections and layer normalization, which are essential for training deep Transformer models effectively.
An encoder layer processes the input sequence, refining its representations through self-attention and a feed-forward network. Each encoder layer has two primary sub-layers:

1. A multi-head self-attention mechanism, in which every position in the input sequence attends to every other position.
2. A position-wise feed-forward network, applied independently to each position.
Crucially, a residual connection followed by layer normalization is applied after each sub-layer. This structure facilitates gradient flow and stabilizes activations, preventing issues like vanishing or exploding gradients in deep stacks of layers.
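To isolate that pattern, a hypothetical SublayerConnection helper is sketched below. It is shown only to make the "Add & Norm" step explicit; the EncoderLayer and DecoderLayer implementations later in this section apply the same pattern inline rather than using such a helper.

import torch
import torch.nn as nn

class SublayerConnection(nn.Module):
    """Hypothetical helper isolating the post-norm "Add & Norm" pattern:
    output = LayerNorm(x + Dropout(sublayer(x)))."""
    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer) -> torch.Tensor:
        # Residual connection around the sub-layer output, then layer normalization.
        return self.norm(x + self.dropout(sublayer(x)))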
The data flow through a single encoder layer can be visualized as follows:
Figure: Data flow within a Transformer Encoder Layer. Dashed lines indicate residual connections.
Let's implement this in PyTorch. We'll assume MultiHeadAttention and PositionwiseFeedForward are already defined classes based on the previous sections' implementations.
import torch
import torch.nn as nn

# Assume MultiHeadAttention and PositionwiseFeedForward classes are defined elsewhere
# from .attention import MultiHeadAttention
# from .feed_forward import PositionwiseFeedForward

class EncoderLayer(nn.Module):
    """
    Represents one layer of the Transformer Encoder.

    It consists of a multi-head self-attention mechanism followed by a
    position-wise fully connected feed-forward network. Residual connections
    and layer normalization are applied after each sub-layer.
    """
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        """
        Args:
            d_model: The dimensionality of the input and output
                (embedding dimension).
            num_heads: The number of attention heads.
            d_ff: The dimensionality of the inner layer of the
                feed-forward network.
            dropout: The dropout rate.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            x: The input tensor to the layer (batch_size, seq_len, d_model).
            mask: The mask for the self-attention mechanism (optional).
                Typically used to ignore padding tokens.
                Shape (batch_size, 1, seq_len) or
                (batch_size, seq_len, seq_len).

        Returns:
            The output tensor from the layer (batch_size, seq_len, d_model).
        """
        # 1. Multi-Head Self-Attention
        attn_output, _ = self.self_attn(query=x, key=x, value=x, mask=mask)
        # Apply residual connection and layer normalization
        x = self.norm1(x + self.dropout(attn_output))  # Add -> Norm

        # 2. Position-wise Feed-Forward Network
        ff_output = self.feed_forward(x)
        # Apply residual connection and layer normalization
        x = self.norm2(x + self.dropout(ff_output))  # Add -> Norm

        return x
In the forward method, note the pattern: the input x is added to the output of the sub-layer (after dropout) before passing through layer normalization (self.norm1 or self.norm2). This is the "Add & Norm" step that is crucial for stable training. The mask argument is passed to the self-attention layer to prevent attention to padding tokens when needed.
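As a quick sanity check, the layer can be exercised with random inputs and a simple padding mask. The snippet below is a sketch: it assumes the MultiHeadAttention and PositionwiseFeedForward implementations from the previous sections are available and follow the convention that True/1 means "attend" and False/0 means "ignore"; the concrete shapes are chosen only for illustration.

batch_size, seq_len, d_model = 2, 10, 512
encoder_layer = EncoderLayer(d_model=d_model, num_heads=8, d_ff=2048, dropout=0.1)

x = torch.randn(batch_size, seq_len, d_model)

# Padding mask: True for real tokens, False for padding; here the second
# sequence is padded after position 6. Shape (batch_size, 1, seq_len).
pad_mask = torch.ones(batch_size, 1, seq_len, dtype=torch.bool)
pad_mask[1, :, 6:] = False

output = encoder_layer(x, mask=pad_mask)
print(output.shape)  # torch.Size([2, 10, 512])

The output shape matches the input shape, which is exactly what allows these layers to be stacked later.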
The decoder layer shares similarities with the encoder layer but has an additional sub-layer to handle information from the encoder output. Each decoder layer has three main sub-layers:

1. A masked multi-head self-attention mechanism over the decoder's own input, where a look-ahead mask prevents each position from attending to future positions.
2. A multi-head cross-attention (encoder-decoder attention) mechanism, where queries come from the decoder and keys and values come from the encoder output.
3. A position-wise feed-forward network.
Like the encoder, each of these three sub-layers is followed by a residual connection and layer normalization.
Here's the data flow for a decoder layer:
Figure: Data flow within a Transformer Decoder Layer. Dashed lines indicate residual connections. The Encoder Output (Memory) provides Keys and Values for the Cross-Attention sub-layer.
Now, let's implement the DecoderLayer in PyTorch.
import torch
import torch.nn as nn

# Assume MultiHeadAttention and PositionwiseFeedForward classes are defined elsewhere
# from .attention import MultiHeadAttention
# from .feed_forward import PositionwiseFeedForward

class DecoderLayer(nn.Module):
    """
    Represents one layer of the Transformer Decoder.

    It consists of masked self-attention, cross-attention (attending to
    encoder output), and a position-wise feed-forward network. Residual
    connections and layer normalization are applied after each sub-layer.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: The dimensionality of the input and output.
            num_heads: The number of attention heads.
            d_ff: The dimensionality of the inner layer of the
                feed-forward network.
            dropout: The dropout rate.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor,
                memory: torch.Tensor,
                src_mask: torch.Tensor | None = None,
                tgt_mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Pass the input through the decoder layer.

        Args:
            x: The input tensor to the decoder layer
                (batch_size, tgt_seq_len, d_model).
            memory: The output tensor from the encoder stack
                (batch_size, src_seq_len, d_model).
            src_mask: Mask for the cross-attention
                (encoder-decoder attention) layer, used to ignore
                padding tokens in the encoder output (optional).
                Shape (batch_size, 1, src_seq_len).
            tgt_mask: Mask for the masked self-attention layer, combines
                look-ahead mask and target padding mask (optional).
                Shape (batch_size, tgt_seq_len, tgt_seq_len).

        Returns:
            The output tensor from the layer
            (batch_size, tgt_seq_len, d_model).
        """
        # 1. Masked Multi-Head Self-Attention
        # The target mask (tgt_mask) prevents attending to future positions.
        self_attn_output, _ = self.self_attn(query=x,
                                             key=x,
                                             value=x,
                                             mask=tgt_mask)
        # Apply residual connection and layer normalization
        x = self.norm1(x + self.dropout(self_attn_output))  # Add -> Norm

        # 2. Multi-Head Cross-Attention (Encoder-Decoder Attention)
        # Query is from the decoder (x); Key/Value are from the encoder
        # output (memory). The source mask (src_mask) prevents attending
        # to padding in the encoder output.
        cross_attn_output, _ = self.cross_attn(query=x,
                                               key=memory,
                                               value=memory,
                                               mask=src_mask)
        # Apply residual connection and layer normalization
        x = self.norm2(x + self.dropout(cross_attn_output))  # Add -> Norm

        # 3. Position-wise Feed-Forward Network
        ff_output = self.feed_forward(x)
        # Apply residual connection and layer normalization
        x = self.norm3(x + self.dropout(ff_output))  # Add -> Norm

        return x
Key points in the DecoderLayer implementation:

1. The self_attn layer receives tgt_mask. This mask is typically a combination of a padding mask (if the target sequence has padding) and a look-ahead mask (a lower triangular matrix) to enforce causality; see the sketch after this list.
2. The cross_attn layer uses the output of the first Add & Norm block (x) as its query. Importantly, the key and value come from the memory argument, which represents the final output of the encoder stack. The src_mask is used here to ignore padding tokens in the original source sequence (encoder input).
3. There are three Add & Norm steps, one for each of the three sub-layers.
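To make the masking concrete, here is a minimal sketch of how a look-ahead mask can be built, combined with a target padding mask, and passed to the DecoderLayer together with a stand-in for the encoder output. It assumes the MultiHeadAttention and PositionwiseFeedForward implementations from the previous sections (with True/1 meaning "attend" and False/0 meaning "ignore"); the subsequent_mask helper and the concrete shapes are illustrative, not part of the earlier code.

def subsequent_mask(size: int) -> torch.Tensor:
    """Illustrative helper: a (1, size, size) lower-triangular boolean mask
    allowing each position to attend only to itself and earlier positions."""
    return torch.tril(torch.ones(1, size, size)).bool()

batch_size, src_len, tgt_len, d_model = 2, 12, 10, 512
decoder_layer = DecoderLayer(d_model=d_model, num_heads=8, d_ff=2048, dropout=0.1)

tgt = torch.randn(batch_size, tgt_len, d_model)     # decoder input embeddings
memory = torch.randn(batch_size, src_len, d_model)  # stand-in for encoder output

# Source padding mask: True for real source tokens. Shape (batch_size, 1, src_len).
src_mask = torch.ones(batch_size, 1, src_len, dtype=torch.bool)
src_mask[1, :, 9:] = False  # second source sequence padded after position 9

# Target mask: padding mask combined with the look-ahead mask.
# (batch_size, 1, tgt_len) & (1, tgt_len, tgt_len) broadcasts to
# (batch_size, tgt_len, tgt_len).
tgt_pad_mask = torch.ones(batch_size, 1, tgt_len, dtype=torch.bool)
tgt_mask = tgt_pad_mask & subsequent_mask(tgt_len)

output = decoder_layer(tgt, memory, src_mask=src_mask, tgt_mask=tgt_mask)
print(output.shape)  # torch.Size([2, 10, 512])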
With these EncoderLayer and DecoderLayer classes defined, we have the fundamental components needed to build the full Encoder and Decoder stacks, which simply involves creating multiple copies of these layers and passing the output of one layer as the input to the next. This stacking allows the model to learn increasingly complex representations of the input and target sequences. We will assemble these stacks in the next section.