Training deep neural networks, particularly the extensive Transformer architectures used in modern LLMs, requires careful attention to numerous details. Among the most significant factors influencing training dynamics is the initial state of the model's parameters, specifically its weights. Simply initializing weights to zero or small random numbers drawn from a naive distribution often leads to significant training difficulties. Proper initialization is not merely a heuristic; it's a foundational requirement for enabling effective learning in deep networks.
The core issues stem from how signals, both activations during the forward pass and gradients during the backward pass, propagate through the network's layers. Consider the computation within a single layer, often involving a matrix multiplication followed by a non-linear activation function. When multiple such layers are stacked, these operations are repeated sequentially.
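Concretely, if $a_{l-1}$ denotes the activations entering layer $l$, with weight matrix $W_l$, bias $b_l$, and activation function $f$, a typical layer computes

$$a_l = f(W_l\, a_{l-1} + b_l)$$

An $N$-layer network applies this transformation $N$ times in succession, so any systematic growth or shrinkage introduced by a single layer compounds across the entire stack.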
During backpropagation, the gradient of the loss with respect to a weight in an early layer is calculated using the chain rule, multiplying gradients from all subsequent layers. If the average magnitude of these gradients (or the Jacobian matrices of the layer transformations) is less than 1, the gradient signal can shrink exponentially as it propagates backward.
$$\frac{\partial L}{\partial W_l} \propto \left( \prod_{k=l+1}^{N} J_k \right) \frac{\partial L}{\partial a_N}$$

Here, $J_k$ represents the Jacobian of layer $k$. If $\|J_k\| < 1$ on average, the product term diminishes rapidly as the number of layers $N - l$ increases.
This phenomenon, known as the vanishing gradient problem, means that the weights in earlier layers receive extremely small updates. Consequently, these layers learn very slowly, or sometimes not at all. This effectively prevents the network from learning complex representations that rely on the coordinated tuning of parameters across its full depth. Historically, this was a major barrier to training deep networks, especially those using activation functions like sigmoid or tanh, whose derivatives never exceed 1 and fall toward zero in their saturation regions.
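This saturation effect is easy to check numerically. The short sketch below (a standalone illustration, separate from the simulation later in this section) evaluates the sigmoid derivative $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, which peaks at $0.25$ at $x = 0$ and decays toward zero as $|x|$ grows, meaning every sigmoid layer scales the backward signal by at most a quarter:

import torch

# Sigmoid derivative: s'(x) = s(x) * (1 - s(x)); its maximum value is 0.25 at x = 0
x = torch.linspace(-6.0, 6.0, steps=7)
s = torch.sigmoid(x)
deriv = s * (1 - s)
for xi, di in zip(x.tolist(), deriv.tolist()):
    print(f"x = {xi:+.1f}  sigmoid'(x) = {di:.4f}")
# At x = 0 the derivative is 0.2500; by |x| = 6 it has fallen below 0.0025,
# so gradients passing through saturated units are almost entirely erased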
Conversely, if the average magnitude of the gradients (or Jacobians) is greater than 1, the gradient signal can grow exponentially as it moves backward through the layers.
$$\left\| \frac{\partial L}{\partial W_l} \right\| \to \infty \quad \text{as } N - l \to \infty \quad \text{if } \|J_k\| > 1$$

This exploding gradient problem leads to excessively large weight updates. Large updates can cause the optimization process to become unstable, potentially overshooting optimal points in the loss landscape or oscillating wildly. In extreme cases, the gradients become so large that they result in numerical overflow (represented as NaN or Inf values), bringing the training process to a halt. While techniques like gradient clipping (discussed in Chapter 17) can mitigate exploding gradients during training, proper initialization aims to prevent them from occurring frequently in the first place.
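For reference, norm-based clipping is a one-line call in PyTorch. The sketch below shows where it would sit in a training step; the model, input, and loss here are placeholders chosen purely for illustration:

import torch
import torch.nn as nn

model = nn.Linear(100, 100)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 100)
loss = model(x).pow(2).mean()  # placeholder loss
loss.backward()

# Rescale all gradients so their combined L2 norm is at most 1.0,
# then apply the (now bounded) update
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()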
Let's look at a simplified simulation. Imagine a very simple linear network with 10 layers. We'll initialize weights either slightly too small or slightly too large and observe the magnitude of a dummy gradient propagated backward.
import torch
import torch.nn as nn
import math

# Simulate signal propagation (simplified)
def check_grad_magnitude(init_scale, num_layers=10):
    """
    Simulates backward gradient magnitude through stacked linear layers.
    """
    # The input dimension doesn't matter much here; the focus is on scale
    layer_dim = 100
    layers = []
    for _ in range(num_layers):
        layer = nn.Linear(layer_dim, layer_dim, bias=False)
        # Initialize weights with a specific standard deviation
        nn.init.normal_(layer.weight, mean=0.0, std=init_scale)
        layers.append(layer)

    # Assume an all-ones gradient arrives at the output
    # (its L2 norm is sqrt(layer_dim) = 10)
    output_grad = torch.ones(1, layer_dim)

    # Propagate the gradient backward manually to watch its magnitude change
    current_grad = output_grad
    with torch.no_grad():
        for i in range(num_layers - 1, -1, -1):
            # PyTorch stores weight as (out_features, in_features) and computes
            # y = x @ W^T, so grad w.r.t. the layer input = grad w.r.t. output @ W
            current_grad = current_grad @ layers[i].weight
            # No per-layer normalization here: in a real network the raw
            # products are exactly what backpropagation computes

    # Gradient magnitude w.r.t. the very first layer's input; this simulates
    # the gradient signal reaching the earliest parameters
    input_grad_norm = torch.norm(current_grad)
    return input_grad_norm.item()

# Check magnitudes
num_layers = 10
small_init_scale = 0.05  # Potentially leads to vanishing gradients
large_init_scale = 0.5   # Potentially leads to exploding gradients
ideal_init_scale = math.sqrt(1.0 / 100)  # Xavier-like scaling: std = sqrt(1/fan_in)

grad_norm_small = check_grad_magnitude(small_init_scale, num_layers)
grad_norm_large = check_grad_magnitude(large_init_scale, num_layers)
grad_norm_ideal = check_grad_magnitude(ideal_init_scale, num_layers)

print(
    f"Initialization Scale: {small_init_scale:.3f}, "
    f"Final Gradient Norm: {grad_norm_small:.4e}"
)
print(
    f"Initialization Scale: {ideal_init_scale:.3f}, "
    f"Final Gradient Norm: {grad_norm_ideal:.4e}"
)
print(
    f"Initialization Scale: {large_init_scale:.3f}, "
    f"Final Gradient Norm: {grad_norm_large:.4e}"
)

# Typical output (exact values vary with the random draw; the orders of
# magnitude are what matter):
# Initialization Scale: 0.050, Final Gradient Norm: ~1e-02  (vanishing)
# Initialization Scale: 0.100, Final Gradient Norm: ~1e+01  (preserved)
# Initialization Scale: 0.500, Final Gradient Norm: ~1e+08  (exploding)
The simple simulation above (which omits normalization layers and non-linearities, both of which complicate the picture further) illustrates how the weight scale directly determines the gradient magnitude after backpropagation through several layers. A small initial scale causes the gradient norm to vanish, while a large scale causes it to explode. An appropriately chosen scale (a variance of 1/fan_in, i.e. a standard deviation of sqrt(1/fan_in), as suggested by Xavier/Glorot-style initialization) keeps the gradient norm close to its original magnitude.
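The reason this scale works can be made precise with a standard variance argument. For a linear map $y = Wx$ with independent zero-mean weights of variance $\mathrm{Var}(W_{ij}) = 1/\text{fan\_in}$, and inputs $x_j$ that are i.i.d., zero-mean, and independent of the weights, each output coordinate satisfies

$$\mathrm{Var}(y_i) = \sum_{j=1}^{\text{fan\_in}} \mathrm{Var}(W_{ij})\,\mathrm{Var}(x_j) = \text{fan\_in} \cdot \frac{1}{\text{fan\_in}} \cdot \mathrm{Var}(x_j) = \mathrm{Var}(x_j)$$

The signal variance is preserved exactly in expectation, which is why the $\text{std} = \sqrt{1/\text{fan\_in}}$ setting kept the gradient norm essentially flat in the simulation.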
Initialization also impacts the forward pass. If weights are too large, the outputs of linear layers can grow significantly, potentially pushing the inputs to activation functions into saturation regions (e.g., for sigmoid or tanh) where gradients are near zero. This again hinders learning. Conversely, if weights are too small, activations might diminish layer by layer, leading to representations collapsing towards zero and reducing the network's effective capacity.
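The forward-pass counterpart of the earlier experiment is sketched below under the same assumptions (square linear layers, now followed by tanh); the helper name check_activation_scale is ad hoc, introduced here only for illustration. It tracks the standard deviation of the activations after the final layer for the same three weight scales:

import torch
import math

def check_activation_scale(init_scale, num_layers=10, layer_dim=100):
    """Track activation magnitude through stacked linear + tanh layers."""
    x = torch.randn(1, layer_dim)
    with torch.no_grad():
        for _ in range(num_layers):
            w = torch.randn(layer_dim, layer_dim) * init_scale
            x = torch.tanh(x @ w.T)
    return x.std().item()

for scale in (0.05, math.sqrt(1.0 / 100), 0.5):
    print(f"init std {scale:.3f} -> final activation std "
          f"{check_activation_scale(scale):.4e}")

# Typical behavior (values vary with the random draw): roughly 1e-03 for
# 0.050 (activations collapsing toward zero), a few times 1e-01 for 0.100
# (signal roughly preserved), and close to 1.0 for 0.500 (tanh pinned
# near its +/-1 saturation values)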
Therefore, the objective of principled weight initialization strategies, such as Xavier and Kaiming initialization which we will discuss next, is to carefully set the initial scale of weights based on the layer dimensions. The goal is to ensure that both activations in the forward pass and gradients in the backward pass maintain a reasonable variance throughout the network, preventing signals from vanishing or exploding and thereby promoting faster, more stable training convergence. This is particularly important for the very deep Transformer models central to this course.
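As a pointer to what follows, both strategies are single calls in PyTorch's torch.nn.init module; the sketch below shows how each would be applied to a freshly constructed layer (choosing between them, and the effect of the nonlinearity argument, is the subject of the next section):

import torch.nn as nn

layer = nn.Linear(512, 512)

# Xavier/Glorot: variance scaled by fan_in and fan_out, suited to tanh/sigmoid
nn.init.xavier_uniform_(layer.weight)

# Kaiming/He: variance scaled by fan_in only, compensating for the half of
# the signal that ReLU zeroes out
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')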