While principled initialization methods like Kaiming and Xavier provide excellent starting points for most layers in deep networks, special consideration is often given to the final layer, particularly the output projection layer that maps the last hidden state to vocabulary logits in a language model. The initialization of this layer can significantly impact the initial loss values and the stability of the training process, especially in the very first few steps.
The output layer directly computes the pre-softmax logits for every token in the vocabulary. If these initial logits have a large variance, several issues can arise: the softmax assigns overconfident probabilities to what are effectively random tokens, the initial cross-entropy loss starts well above the uniform-prediction baseline of ln(V), and the correspondingly large gradients can destabilize the very first optimization steps.
Consider the standard Kaiming or Xavier initialization. These methods aim to preserve variance through the network layers. However, for the final layer mapping a hidden dimension d_model to a large vocabulary V, the resulting scale might still be too large for stable initial training, especially given the scale of V.
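A quick way to see the effect is to compare the initial cross-entropy loss under different weight scales for the output projection. The sketch below uses random hidden states and illustrative sizes (not taken from any particular model); with a 10,000-token vocabulary, a well-behaved initial loss should sit near ln(10000) ≈ 9.21, while larger logit variance pushes it higher.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes; not taken from any particular model
vocab_size, d_model, batch = 10000, 512, 32
hidden = torch.randn(batch, d_model)              # stand-in for final hidden states
targets = torch.randint(0, vocab_size, (batch,))

# Compare a Kaiming-like std dev for this fan-in with two smaller choices
for std in (math.sqrt(2.0 / d_model), 0.02, 0.002):
    weight = torch.randn(vocab_size, d_model) * std
    logits = hidden @ weight.t()                  # (batch, vocab_size) pre-softmax logits
    loss = F.cross_entropy(logits, targets)
    print(f"weight std={std:.4f}  logit std={logits.std().item():.2f}  "
          f"loss={loss.item():.2f}  uniform baseline={math.log(vocab_size):.2f}")

The smaller weight scales keep the logit variance low, so the initial loss stays close to the uniform baseline instead of starting inflated.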
A common practice is to initialize the weight matrix of the final linear layer with a smaller standard deviation compared to what Kaiming or Xavier initialization would typically suggest. The bias term is often initialized to zero.
For instance, instead of using the default standard deviation from Kaiming uniform/normal, one might manually specify a smaller standard deviation, often scaling it inversely with the network depth or based on empirical findings. Some architectures or codebases adopt specific small standard deviations like 0.02 or scale the standard Kaiming/Xavier standard deviation by a factor (e.g., 0.5).
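As a minimal sketch of those two options applied to a standalone output projection (the 0.02 value and the 0.5 scaling factor here are illustrative, not prescriptive; in practice you would pick one of the two):

import math
import torch.nn as nn

d_model, vocab_size = 512, 10000
lm_head = nn.Linear(d_model, vocab_size, bias=False)

# Option 1: a fixed small standard deviation
nn.init.normal_(lm_head.weight, mean=0.0, std=0.02)

# Option 2: scale down the std dev Kaiming normal would use for this fan-in
kaiming_std = math.sqrt(2.0 / d_model)   # sqrt(2 / fan_in), ReLU gain
nn.init.normal_(lm_head.weight, mean=0.0, std=0.5 * kaiming_std)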
Another perspective comes from architectures where the final layer is part of a residual connection pathway. In models like GPT-2/3, layers contributing to the residual stream sometimes receive initializations scaled by the number of layers to prevent the residual signal from growing too large. While this often applies to intermediate feed-forward or attention output projections within residual blocks, a similar principle might be applied to the final output projection if stability issues are observed. The core idea is to keep the initial output magnitudes controlled.
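For reference, here is a sketch of that residual-path scaling; the module names, the 0.02 base value, and the factor of two per block are illustrative of GPT-2-style code rather than a fixed specification.

import math
import torch.nn as nn

# Sketch of GPT-2-style scaling for projections that write into the residual stream.
# The module names below are hypothetical; in GPT-2-like code they correspond to the
# attention output projection and the feed-forward output projection in each block.
num_layers = 6
base_std = 0.02
# Each block typically contributes twice to the residual stream (attention + feed-forward),
# so the scaling factor is often 1 / sqrt(2 * num_layers)
residual_std = base_std / math.sqrt(2 * num_layers)

attn_out_proj = nn.Linear(512, 512)     # hypothetical attention output projection
ffn_out_proj = nn.Linear(2048, 512)     # hypothetical feed-forward output projection
for proj in (attn_out_proj, ffn_out_proj):
    nn.init.normal_(proj.weight, mean=0.0, std=residual_std)
    nn.init.zeros_(proj.bias)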
Let's see how you might apply a custom, smaller initialization to the final linear layer in a PyTorch model. Assume your model has a final layer named lm_head, which is an instance of torch.nn.Linear.
import torch
import torch.nn as nn
import math


class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dim_feedforward):
        super().__init__()
        self.d_model = d_model
        # ... (embedding, positional encoding, transformer encoder layers) ...
        self.embedding = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # Bias is often disabled (False) or zeroed for the output projection
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.apply(self._init_weights)  # Apply initialization recursively

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Standard Kaiming initialization for most linear layers
            # (choose the nonlinearity argument to match the activation used)
            nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Standard initialization for embeddings (example std dev)
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

        # Special handling for the final output layer (lm_head):
        # check whether the module *is* the lm_head instance
        if hasattr(self, 'lm_head') and module is self.lm_head:
            # Use a smaller standard deviation for the final layer weights.
            # Example: a fixed small value, or a scaled-down Kaiming std dev
            std_dev = 0.01  # Or some other empirically determined small value
            # Alternative: scale based on the Kaiming std dev for this layer
            # kaiming_std = math.sqrt(2.0 / module.in_features)  # ReLU gain, a=0
            # std_dev = kaiming_std * 0.5  # Scale it down
            print(f"Applying special initialization to lm_head with std={std_dev}")
            nn.init.normal_(module.weight, mean=0.0, std=std_dev)
            if module.bias is not None:
                # Ensure the bias is zero if it exists
                nn.init.zeros_(module.bias)

    def forward(self, src):
        # ... (forward pass logic) ...
        embedded = self.embedding(src) * math.sqrt(self.d_model)
        # Add positional encoding here
        output = self.transformer_encoder(embedded)
        logits = self.lm_head(output)
        return logits


# Example usage:
vocab_size = 10000
d_model = 512
nhead = 8
num_layers = 6
dim_feedforward = 2048

model = SimpleTransformerLM(
    vocab_size, d_model, nhead, num_layers, dim_feedforward
)
# You would see the print statement during model instantiation:
# "Applying special initialization to lm_head with std=0.01"
In this example, the _init_weights function is applied recursively to all modules. It first applies standard Kaiming normal initialization to generic linear layers and a normal initialization to embeddings. Then, it specifically checks whether the current module being initialized is the lm_head instance. If it is, it overrides the standard initialization with a normal distribution using a much smaller standard deviation (std=0.01 in this case).
Choosing the exact standard deviation or scaling factor for the final layer often involves some empirical tuning based on the specific model architecture, vocabulary size, and initial training observations. Monitoring the initial loss values and gradient norms can help determine if the chosen initialization is appropriate. The goal is to start the training in a stable regime where gradients are informative but not excessively large.
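As a rough sanity check (assuming the model and vocab_size defined above), you might run a single forward and backward pass on random tokens, compare the initial loss to the uniform-prediction baseline ln(vocab_size), and inspect the overall gradient norm. The batch and sequence sizes below are arbitrary.

import math
import torch
import torch.nn.functional as F

# Random token ids; real data is not needed for this check
batch_size, seq_len = 8, 64
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

logits = model(tokens)                                            # (batch, seq, vocab)
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
loss.backward()

# Global L2 norm of all parameter gradients
grad_norm = torch.sqrt(sum((p.grad ** 2).sum()
                           for p in model.parameters() if p.grad is not None))

print(f"initial loss: {loss.item():.2f} (uniform baseline: {math.log(vocab_size):.2f})")
print(f"global gradient norm: {grad_norm.item():.2f}")

If the loss starts far above the baseline or the gradient norm is very large, the output-layer scale is a natural first thing to revisit.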