Having established the principles of Xavier and Kaiming initialization in the previous sections, let's now examine how these techniques are applied to the different types of layers within a standard Transformer model. Initializing each component appropriately is important for maintaining stable signal propagation and enabling effective gradient-based learning in these deep architectures.
Transformer models begin with embedding layers that convert input token IDs and their positions into dense vector representations.
Token Embeddings (nn.Embedding): This layer maps discrete token IDs to continuous vectors. A common practice is to initialize these embedding weights from a normal distribution with a mean of 0 and a small standard deviation such as 0.02, written N(0, 0.02²). This approach provides initial diversity in embeddings without introducing excessively large values that could destabilize early training. While it doesn't strictly follow the fan-in/fan-out logic of Xavier/Kaiming (the input is effectively a one-hot lookup, so each token touches a single row), this empirical method works well in practice. Some implementations might adapt Xavier initialization, treating the embedding dimension as the output size.
Positional Embeddings: If using learned positional embeddings (also often an nn.Embedding layer), similar initialization strategies apply, typically drawing weights from N(0, σ²) with a small σ. For sinusoidal positional encodings, initialization isn't required as these are fixed, deterministic values.
Here's how you might initialize an embedding layer in PyTorch:
import torch
import torch.nn as nn
vocab_size = 30000
hidden_dim = 768
max_position_embeddings = 512
# Token Embeddings
token_embedding_layer = nn.Embedding(vocab_size, hidden_dim)
nn.init.normal_(token_embedding_layer.weight, mean=0.0, std=0.02)
# Learned Positional Embeddings (Example)
positional_embedding_layer = nn.Embedding(
    max_position_embeddings, hidden_dim
)
nn.init.normal_(positional_embedding_layer.weight, mean=0.0, std=0.02)
print("Token Embedding Layer Weight Shape:", token_embedding_layer.weight.shape)
print(
    "Positional Embedding Layer Weight Shape:",
    positional_embedding_layer.weight.shape
)
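For contrast with the learned embeddings above, the fixed sinusoidal alternative can be sketched as follows. This is a minimal sketch assuming the standard sine/cosine formulation with a base of 10000. The function name `sinusoidal_positional_encoding` is hypothetical, and the sizes reuse the example dimensions from above.

```python
import math

import torch


def sinusoidal_positional_encoding(max_len: int, dim: int) -> torch.Tensor:
    """Build a fixed (max_len, dim) table of sine/cosine position encodings."""
    if dim % 2 != 0:
        raise ValueError("dim must be even for paired sine/cosine channels")

    # Column vector of positions: shape (max_len, 1)
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)

    # One frequency per sine/cosine pair: shape (dim / 2,)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )

    pe = torch.zeros(max_len, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe


pe = sinusoidal_positional_encoding(max_len=512, dim=768)
print("Sinusoidal encoding shape:", pe.shape)
```

Because the table is computed deterministically and has no trainable parameters, nothing here needs an initialization scheme. In a model it would typically be stored as a buffer rather than as an nn.Parameter.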
The core of the Transformer lies in its self-attention and cross-attention mechanisms. These involve several linear projection layers:
Query (Q), Key (K), Value (V) Projections (nn.Linear): These layers project the input embeddings into the Q, K, and V spaces. Since they are standard linear transformations followed by operations that benefit from controlled variance (like scaled dot-product attention), Kaiming initialization, in either its uniform or normal variant, is frequently a suitable choice. This helps maintain variance through the initial projection steps. Because no ReLU directly follows these projections, Xavier initialization is also widely used here; PyTorch's nn.MultiheadAttention, for example, applies Xavier uniform to its input projection weights.
Output Projection (nn.Linear): After computing attention outputs (often from multiple heads), another linear layer projects the concatenated outputs back to the model's hidden dimension. Similar to Q/K/V projections, Kaiming initialization is a common and reasonable default for this layer.
Consider a linear projection layer within a multi-head attention block:
# Example: Initializing a QKV projection layer
hidden_dim = 768
projection_layer = nn.Linear(hidden_dim, hidden_dim * 3) # Combined Q, K, V
# Apply Kaiming uniform initialization
# Assumes a ReLU-like non-linearity might follow implicitly
# in the broader computation graph or feed-forward layer.
# With a=0, 'leaky_relu' yields the same gain as 'relu' (sqrt(2)).
nn.init.kaiming_uniform_(
    projection_layer.weight,
    a=0,
    mode='fan_in',
    nonlinearity='leaky_relu'
)
# Initialize bias to zero (common practice)
if projection_layer.bias is not None:
    nn.init.zeros_(projection_layer.bias)
print("Projection Layer Weight Shape:", projection_layer.weight.shape)
# Example: Initializing the output projection layer
output_projection = nn.Linear(hidden_dim, hidden_dim)
# Using Kaiming normal as another option
nn.init.kaiming_normal_(
    output_projection.weight,
    mode='fan_out',
    nonlinearity='relu'
)
if output_projection.bias is not None:
    nn.init.zeros_(output_projection.bias)
print("Output Projection Weight Shape:", output_projection.weight.shape)
Note that the choice of mode (fan_in or fan_out) and nonlinearity can be adjusted based on theoretical considerations or empirical results. fan_in preserves the variance of activations in the forward pass, while fan_out preserves the variance of gradients in the backward pass. PyTorch's gain table has no dedicated entry for GeLU or SwiGLU, so for these ReLU-like activations nonlinearity='relu' (or nonlinearity='leaky_relu' with a=0, which yields the same gain) is commonly used as an approximation.
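To make the effect of mode concrete, the sketch below compares the empirical weight standard deviation after kaiming_normal_ with the value gain / sqrt(fan) for each mode. The layer sizes are the combined Q/K/V example from above; the variable names are illustrative, and the fans are read directly off the nn.Linear weight shape (out_features, in_features).

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

in_features, out_features = 768, 2304
layer = nn.Linear(in_features, out_features)  # weight shape: (2304, 768)

# Gain for ReLU is sqrt(2)
gain = nn.init.calculate_gain('relu')

# mode='fan_in': std is set from the number of inputs per output unit
nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
std_fan_in = layer.weight.std().item()
expected_fan_in = gain / math.sqrt(in_features)

# mode='fan_out': std is set from the number of outputs per input unit
nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
std_fan_out = layer.weight.std().item()
expected_fan_out = gain / math.sqrt(out_features)

print(f"fan_in : empirical {std_fan_in:.4f}, formula {expected_fan_in:.4f}")
print(f"fan_out: empirical {std_fan_out:.4f}, formula {expected_fan_out:.4f}")
```

For a square layer the two modes coincide. They only differ when the layer expands or contracts the dimension, as the combined Q/K/V projection and the FFN layers do.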
Each Transformer block contains a position-wise feed-forward network (FFN), typically consisting of two linear layers with a non-linear activation function in between (e.g., ReLU, GeLU, SwiGLU).
First Linear Layer (nn.Linear): This layer usually expands the dimensionality. Kaiming initialization is strongly recommended here, matched to the non-linearity used (e.g., nonlinearity='relu' for ReLU/GeLU). This keeps the variance of the outputs after the activation function controlled.
Second Linear Layer (nn.Linear): This layer projects the dimension back down to the model's hidden size. While Kaiming initialization can be used, some research and implementations (like those following GPT-2/3 practices) suggest initializing this layer with a smaller variance. The rationale is tied to the residual connection that follows the FFN block: each residual branch adds its output to the residual stream, so the stream's variance grows with depth unless the contributing layers are scaled down. Scaling down the initialization of the layer that feeds the residual branch helps stabilize training, particularly in very deep networks. The scaling is typically proportional to 1/√N, where N is the number of residual blocks or layers.
hidden_dim = 768
ffn_intermediate_dim = hidden_dim * 4 # Common expansion factor
num_layers = 12 # Example number of transformer layers
# First FFN layer (expansion)
ffn_layer1 = nn.Linear(hidden_dim, ffn_intermediate_dim)
# Use Kaiming matching the activation (e.g., GeLU which is ReLU-like)
nn.init.kaiming_uniform_(
    ffn_layer1.weight,
    a=0,
    mode='fan_in',
    nonlinearity='leaky_relu'
)
if ffn_layer1.bias is not None:
    nn.init.zeros_(ffn_layer1.bias)
# Second FFN layer (projection back)
ffn_layer2 = nn.Linear(ffn_intermediate_dim, hidden_dim)
# Option 1: Standard Kaiming
# nn.init.kaiming_uniform_(ffn_layer2.weight, a=0, mode='fan_out',
# nonlinearity='linear') # No activation after this
# Option 2: Scaled Initialization (GPT-style for residual connections)
# Initialize with smaller standard deviation, scaled by number of layers
# Heuristic: each layer has two residual branches (attention and FFN)
residual_scaling_factor = 2 * num_layers
std_dev = 0.02 / (residual_scaling_factor**0.5)
nn.init.normal_(ffn_layer2.weight, mean=0.0, std=std_dev)
if ffn_layer2.bias is not None:
    nn.init.zeros_(ffn_layer2.bias)
print("FFN Layer 1 Weight Shape:", ffn_layer1.weight.shape)
print("FFN Layer 2 Weight Shape:", ffn_layer2.weight.shape)
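The motivation for the scaled option can be illustrated with a toy simulation. This is a sketch under strong simplifying assumptions rather than a model of a real Transformer: every residual branch is replaced by a single random down-projection applied to fresh unit-variance activations, and branches are treated as independent. Under those assumptions each branch adds roughly ffn_dim × std² of variance to the residual stream. The helper `final_stream_variance` is hypothetical.

```python
import math

import torch

torch.manual_seed(0)

hidden_dim, ffn_dim, num_layers = 768, 3072, 12
num_branches = 2 * num_layers  # attention + FFN branch per layer


def final_stream_variance(weight_std: float) -> float:
    """Variance of a toy residual stream after all branches have been added."""
    h = torch.randn(256, hidden_dim)  # unit-variance starting stream
    for _ in range(num_branches):
        # Stand-in for the (unit-variance) activations entering the
        # final linear layer of a residual branch.
        u = torch.randn(256, ffn_dim)
        w = torch.randn(ffn_dim, hidden_dim) * weight_std
        h = h + u @ w  # residual addition
    return h.var().item()


unscaled = final_stream_variance(0.02)
scaled = final_stream_variance(0.02 / math.sqrt(num_branches))

print(f"Stream variance, std=0.02:                {unscaled:.2f}")
print(f"Stream variance, std=0.02/sqrt(2*layers): {scaled:.2f}")
```

In this toy setting the unscaled stream variance grows linearly with the number of branches, while the scaled version adds about one branch's worth of variance in total, regardless of depth.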
Layer Normalization (nn.LayerNorm) has learnable affine parameters: gain (γ) and bias (β). Standard practice is to initialize γ to 1 and β to 0. With these values the affine step is initially an identity map, so the layer outputs the plain normalized activations, and the network can learn deviations during training if necessary.
layer_norm = nn.LayerNorm(hidden_dim)
# Default initialization in PyTorch nn.LayerNorm is already ones for weight
# (gamma) and zeros for bias (beta), but explicit initialization looks like:
# nn.init.ones_(layer_norm.weight) # Gamma
# nn.init.zeros_(layer_norm.bias) # Beta
print("LayerNorm Gamma (weight) Shape:", layer_norm.weight.shape)
print("LayerNorm Beta (bias) Shape:", layer_norm.bias.shape)
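A quick way to confirm this behavior is to compare a freshly constructed nn.LayerNorm against a manual normalization. The sketch below assumes default constructor arguments; the input scale and shift are arbitrary illustrative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_dim = 768
layer_norm = nn.LayerNorm(hidden_dim)

# Defaults: gamma (weight) is all ones, beta (bias) is all zeros
gamma_is_one = torch.equal(layer_norm.weight, torch.ones(hidden_dim))
beta_is_zero = torch.equal(layer_norm.bias, torch.zeros(hidden_dim))

# Input with a non-zero mean and non-unit variance
x = torch.randn(4, hidden_dim) * 5.0 + 3.0
y = layer_norm(x)

# Manual normalization over the last dimension (biased variance, same eps)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)

matches_manual = torch.allclose(y, manual, atol=1e-4)

print("gamma all ones:", gamma_is_one)
print("beta all zeros:", beta_is_zero)
print("output equals plain normalization:", matches_manual)
```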
The final layer of a decoder-only or encoder-decoder Transformer typically projects the hidden states to the vocabulary size, often followed by a softmax that converts the resulting logits into a probability distribution.
Final Output Projection (nn.Linear): This layer maps the final hidden dimension to the size of the vocabulary. Similar to the second FFN layer, directly applying standard Xavier or Kaiming might result in initial outputs that are too large, potentially leading to overly confident predictions and instability early in training. It's common practice to initialize this layer with a small standard deviation, similar to token embeddings (e.g., N(0, 0.02²)), or to tie its weights to the token embedding matrix (weight tying); a separate initialization is only needed when the weights aren't tied.
output_projection_vocab = nn.Linear(hidden_dim, vocab_size)
# Initialize with potentially smaller scale compared to internal layers
nn.init.normal_(output_projection_vocab.weight, mean=0.0, std=0.02)
if output_projection_vocab.bias is not None:
    nn.init.zeros_(output_projection_vocab.bias)
print("Final Output Projection Shape:", output_projection_vocab.weight.shape)
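Weight tying can be sketched as follows. This is a minimal illustration, not a full model: the names `token_embedding` and `lm_head` are hypothetical, the head is created without a bias, and random unit-variance vectors stand in for normalized final hidden states. It also compares the initial cross-entropy loss with log(vocab_size), the loss of a uniform prediction, as a rough check that the initial logits are not overconfident.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, hidden_dim = 30000, 768

token_embedding = nn.Embedding(vocab_size, hidden_dim)
nn.init.normal_(token_embedding.weight, mean=0.0, std=0.02)

# nn.Linear stores its weight as (out_features, in_features), i.e.
# (vocab_size, hidden_dim), which matches the embedding matrix shape.
lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
lm_head.weight = token_embedding.weight  # both modules share one Parameter

# Stand-in for normalized final hidden states and their target tokens
hidden_states = torch.randn(8, hidden_dim)
targets = torch.randint(0, vocab_size, (8,))

with torch.no_grad():
    logits = lm_head(hidden_states)
    initial_loss = F.cross_entropy(logits, targets).item()

uniform_loss = math.log(vocab_size)

print("Weights shared:", lm_head.weight is token_embedding.weight)
print(f"Initial loss: {initial_loss:.2f} (uniform baseline: {uniform_loss:.2f})")
```

Once the weights are tied, any initializer applied to lm_head.weight also overwrites the embedding matrix, so the shared parameter should be initialized exactly once.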
In summary, general principles like Xavier and Kaiming provide excellent starting points, but effectively initializing a deep Transformer involves applying them thoughtfully to each component. Empirical adjustments, such as smaller standard deviations for embeddings or for specific residual layers, reflect architectural choices and training dynamics observed in large models. Careful initialization sets the stage for more stable and efficient training.
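The per-component choices discussed in this section could be gathered into one helper, sketched below. Everything named here is an assumption for illustration: `TinyBlock` is a hypothetical stand-in for a real Transformer block, `init_transformer_weights` is a hypothetical helper, and matching residual output layers by the attribute name `ffn_layer2` is just one possible convention.

```python
import math

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical minimal block, used only to demonstrate the helper."""

    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.ffn_layer1 = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.ffn_layer2 = nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ffn_layer2(torch.relu(self.ffn_layer1(self.norm(x))))


def init_transformer_weights(
    model: nn.Module,
    num_layers: int,
    base_std: float = 0.02,
    residual_names: tuple = ("ffn_layer2",),
) -> None:
    """Apply the per-layer-type initialization choices from this section."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Embedding):
            # Small-variance normal for token / positional embeddings
            nn.init.normal_(module.weight, mean=0.0, std=base_std)
        elif isinstance(module, nn.LayerNorm):
            # Identity affine transform: gamma = 1, beta = 0
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            if name.split(".")[-1] in residual_names:
                # Layers feeding a residual connection: scaled-down normal
                std = base_std / math.sqrt(2 * num_layers)
                nn.init.normal_(module.weight, mean=0.0, std=std)
            else:
                # Other linear layers: Kaiming uniform with ReLU gain
                nn.init.kaiming_uniform_(
                    module.weight, a=0, mode="fan_in", nonlinearity="relu"
                )
            if module.bias is not None:
                nn.init.zeros_(module.bias)


torch.manual_seed(0)
block = TinyBlock()
init_transformer_weights(block, num_layers=12)
print("ffn_layer2 weight std:", block.ffn_layer2.weight.std().item())
```

A real model would extend `residual_names` to cover the attention output projection as well, and would skip or special-case a tied output head so that the shared weights are not initialized twice.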
© 2025 ApX Machine Learning