With the individual building blocks (EncoderLayer and DecoderLayer) now implemented, we can proceed to assemble the complete Transformer model. This involves stacking these layers, adding the necessary embedding layers, incorporating positional encoding, and defining the final output projection.

Recall that the Transformer architecture follows an encoder-decoder structure. The encoder processes the input sequence and generates a context-rich representation (often called memory). The decoder then uses this memory along with the target sequence (during training) or the previously generated tokens (during inference) to produce the output sequence.

Let's define a PyTorch nn.Module for the complete Transformer.
import torch
import torch.nn as nn
import math

# Assume the following modules are defined in previous sections/files:
# from .encoder import EncoderLayer
# from .decoder import DecoderLayer
# from .attention import MultiHeadAttention
# from .feed_forward import PositionwiseFeedForward
# from .embeddings import PositionalEncoding, Embeddings # Assuming combined Embedding + Positional

# Placeholder definitions for required classes (replace with actual imports)
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        # Example placeholder
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(  # Example placeholder
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Simplified forward for placeholder
        src2, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                 key_padding_mask=src_key_padding_mask)
        src = src + self.dropout(src2)
        src = self.norm1(src)
        src2 = self.feed_forward(src)
        src = src + self.dropout(src2)
        src = self.norm2(src)
        return src


class DecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        # Example placeholder
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        # Example placeholder
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )
        self.feed_forward = nn.Sequential(  # Example placeholder
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Simplified forward for placeholder
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                                 key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm1(tgt)
        tgt2, _ = self.multihead_attn(
            tgt, memory, memory, attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.feed_forward(tgt)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm3(tgt)
        return tgt


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # Use register_buffer so pe is not a model parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        # Adapt positional encoding to batch format [batch_size, seq_len, embedding_dim].
        # Stored pe has shape [max_len, 1, embedding_dim]; we need [1, seq_len, embedding_dim].
        # Slice pe for the current sequence length and transpose.
        pe_for_seq = self.pe[:x.size(1), :].permute(1, 0, 2)
        x = x + pe_for_seq
        return self.dropout(x)


# --- Main Transformer Class ---
class TransformerModel(nn.Module):
    """
    Full Transformer model implementation.
    """
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int,
                 d_model: int, nhead: int, num_encoder_layers: int,
                 num_decoder_layers: int, dim_feedforward: int,
                 dropout: float = 0.1, max_len: int = 5000):
        """
        Args:
            src_vocab_size: Size of the source vocabulary.
            tgt_vocab_size: Size of the target vocabulary.
            d_model: Dimension of the embeddings and model layers.
            nhead: Number of attention heads.
            num_encoder_layers: Number of stacked encoder layers.
            num_decoder_layers: Number of stacked decoder layers.
            dim_feedforward: Dimension of the feed-forward network hidden layer.
            dropout: Dropout rate.
            max_len: Maximum sequence length for positional encoding.
        """
        super().__init__()
        self.d_model = d_model

        self.src_tok_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len)

        # Use nn.ModuleList for lists of layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_decoder_layers)
        ])

        # Final linear layer to project decoder output to vocabulary size
        self.generator = nn.Linear(d_model, tgt_vocab_size)

        # Optional: Weight tying between target embedding and final linear layer
        # self.tgt_tok_emb.weight = self.generator.weight # Requires same dimension

        self._reset_parameters()

    def _reset_parameters(self):
        """Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src: torch.Tensor, src_mask: torch.Tensor = None,
               src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Passes the source sequence through the encoder stack.

        Args:
            src: Source sequence tensor (batch_size, src_seq_len).
            src_mask: Mask for source sequence attention (src_seq_len, src_seq_len).
                Prevents attention to future positions if needed (usually not for encoder).
            src_key_padding_mask: Mask for padding tokens in the source sequence
                (batch_size, src_seq_len).

        Returns:
            Encoder output tensor (batch_size, src_seq_len, d_model).
        """
        # Embed tokens and add positional encoding
        src_emb = self.src_tok_emb(src) * math.sqrt(self.d_model)
        src_emb = self.pos_encoder(src_emb)

        # Pass through each encoder layer
        memory = src_emb
        for layer in self.encoder_layers:
            memory = layer(memory, src_mask=src_mask,
                           src_key_padding_mask=src_key_padding_mask)
        return memory

    def decode(self, tgt: torch.Tensor, memory: torch.Tensor,
               tgt_mask: torch.Tensor = None,
               memory_mask: torch.Tensor = None,
               tgt_key_padding_mask: torch.Tensor = None,
               memory_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Passes the target sequence and encoder memory through the decoder stack.

        Args:
            tgt: Target sequence tensor (batch_size, tgt_seq_len).
            memory: Encoder output tensor (batch_size, src_seq_len, d_model).
            tgt_mask: Mask for target sequence self-attention (tgt_seq_len, tgt_seq_len).
                Prevents attention to future positions.
            memory_mask: Mask for encoder-decoder attention (tgt_seq_len, src_seq_len).
                Usually not needed unless specific cross-attention masking is required.
            tgt_key_padding_mask: Mask for padding tokens in the target sequence
                (batch_size, tgt_seq_len).
            memory_key_padding_mask: Mask for padding tokens in the source sequence,
                used in encoder-decoder attention (batch_size, src_seq_len).

        Returns:
            Decoder output tensor (batch_size, tgt_seq_len, d_model).
        """
        # Embed tokens and add positional encoding
        tgt_emb = self.tgt_tok_emb(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)

        # Pass through each decoder layer
        output = tgt_emb
        for layer in self.decoder_layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask)
        return output

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: torch.Tensor = None,
                tgt_mask: torch.Tensor = None,
                memory_mask: torch.Tensor = None,
                src_key_padding_mask: torch.Tensor = None,
                tgt_key_padding_mask: torch.Tensor = None,
                memory_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """
        Full forward pass through the Transformer model.

        Args:
            src: Source sequence tensor (batch_size, src_seq_len).
            tgt: Target sequence tensor (batch_size, tgt_seq_len).
            src_mask: Mask for source sequence attention.
            tgt_mask: Mask for target sequence self-attention.
            memory_mask: Mask for encoder-decoder attention.
            src_key_padding_mask: Padding mask for source sequence.
            tgt_key_padding_mask: Padding mask for target sequence.
            memory_key_padding_mask: Padding mask for source sequence used in cross-attention.

        Returns:
            Output logits tensor (batch_size, tgt_seq_len, tgt_vocab_size).
        """
        memory = self.encode(src, src_mask, src_key_padding_mask)
        decoder_output = self.decode(tgt, memory, tgt_mask, memory_mask,
                                     tgt_key_padding_mask,
                                     memory_key_padding_mask)
        logits = self.generator(decoder_output)
        return logits
Initialization (__init__):
- Token embedding layers (nn.Embedding) are created for the source and target vocabularies. The size of these embeddings is d_model.
- A single PositionalEncoding module is created. Since its calculation is independent of vocabulary, it can be shared. Note the multiplication by sqrt(d_model) applied to the token embeddings before adding positional encoding, as done in the original paper.
- nn.ModuleList is used to hold the stacks of EncoderLayer and DecoderLayer instances. This ensures that the layers are properly registered as submodules.
- A final nn.Linear layer, self.generator, projects the d_model-dimensional output of the decoder stack into logits over the target vocabulary (tgt_vocab_size).
- Parameter initialization (_reset_parameters) using Xavier uniform initialization is applied, a common practice for Transformers.
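As a quick sanity check on this wiring, here is a minimal instantiation sketch. The hyperparameter values (roughly the "base" configuration from the original paper) and the vocabulary sizes are illustrative assumptions, not values prescribed by this section.

# Illustrative hyperparameters; adjust for your task
model = TransformerModel(
    src_vocab_size=10_000, tgt_vocab_size=10_000,   # assumed vocabulary sizes
    d_model=512, nhead=8,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048, dropout=0.1
)

# Because the layer stacks live in nn.ModuleList, they are registered as
# submodules and therefore appear in model.parameters()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")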
Encoding (encode):
- Takes the source sequence (src), applies embedding and positional encoding, and then sequentially passes the result through each EncoderLayer in the stack.
- Accepts optional masks (src_mask, src_key_padding_mask) which are passed down to the underlying attention mechanisms within each EncoderLayer. src_key_padding_mask is important to prevent attention computation on padding tokens.
- The output (memory) represents the processed source sequence context.
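Continuing the sketch above, the padding mask can be built directly from the token ids. The padding index (PAD_IDX = 0) and the toy batch are assumptions for illustration; with nn.MultiheadAttention, True entries in key_padding_mask mark positions to ignore.

PAD_IDX = 0  # assumed padding token id

# Toy source batch of shape (batch_size=2, src_seq_len=5); the first sequence is padded
src = torch.tensor([[5, 12, 7, PAD_IDX, PAD_IDX],
                    [3, 9, 14, 21, 6]])
src_key_padding_mask = (src == PAD_IDX)  # True where attention should be blocked

memory = model.encode(src, src_key_padding_mask=src_key_padding_mask)
print(memory.shape)  # torch.Size([2, 5, 512]) with d_model=512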
Decoding (decode):
- Takes the target sequence (tgt), the memory from the encoder, and various masks.
- Passes the embedded target sequence through each DecoderLayer. Each DecoderLayer performs self-attention on the target sequence (using tgt_mask and tgt_key_padding_mask) and cross-attention with the encoder memory (using memory_mask and memory_key_padding_mask). The tgt_mask is particularly important here to prevent the decoder from attending to future tokens during training (causal masking).
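A common way to produce this causal mask is with torch.triu. The helper below and the toy target batch are illustrative assumptions; for a boolean attn_mask in nn.MultiheadAttention, True marks positions that may not be attended to.

def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # Upper-triangular True values block attention to future positions
    return torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)

# Toy target prefix of shape (batch_size=2, tgt_seq_len=4), paired with the memory from above
tgt = torch.tensor([[1, 4, 8, 2],
                    [1, 7, 3, 2]])
tgt_mask = generate_square_subsequent_mask(tgt.size(1))

out = model.decode(tgt, memory, tgt_mask=tgt_mask,
                   memory_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([2, 4, 512])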
Forward Pass (forward):
- The forward method orchestrates the entire process.
- It first calls encode to get the memory.
- It then calls decode with the target sequence and the memory.
- Finally, it applies the generator linear layer to the decoder output to produce the final logits for each position in the target sequence. These logits can then be used with a loss function (like Cross-Entropy) during training or processed further (e.g., using argmax or sampling) during inference.

The following diagram illustrates the high-level data flow within the assembled Transformer model:
High-level data flow in the Transformer model, showing the path from input tokens through embeddings, encoder, decoder, and final output projection.
This complete TransformerModel class provides a functional implementation based on the components developed previously. It encapsulates the core architecture, ready for integration into a training loop. The subsequent chapters will build upon this foundation, exploring how to scale this architecture, train it effectively on large datasets, and optimize it for various tasks.
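To show how the model slots into such a loop, here is a minimal teacher-forcing training step. The loss, optimizer, and shifting convention (decoder input tgt[:, :-1] predicting tgt[:, 1:]) are standard choices rather than requirements of this section, and the snippet reuses PAD_IDX and generate_square_subsequent_mask from the sketches above.

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(src: torch.Tensor, tgt: torch.Tensor) -> float:
    # Teacher forcing: the decoder sees tgt[:, :-1] and learns to predict tgt[:, 1:]
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    tgt_mask = generate_square_subsequent_mask(tgt_in.size(1))

    logits = model(src, tgt_in,
                   tgt_mask=tgt_mask,
                   src_key_padding_mask=(src == PAD_IDX),
                   tgt_key_padding_mask=(tgt_in == PAD_IDX),
                   memory_key_padding_mask=(src == PAD_IDX))

    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy
    loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()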