While Xavier initialization provides a solid foundation for layers with symmetric activations like tanh, the situation changed significantly with the widespread adoption of the Rectified Linear Unit (ReLU) activation function. ReLU, defined as $f(x) = \max(0, x)$, introduces an asymmetry: it outputs zero for all negative inputs. This means that, on average, roughly half of the activations flowing out of a ReLU unit are zero. The assumptions underlying Xavier initialization, which balance variance based on symmetric activations, don't account for this behavior. Consequently, using Xavier initialization with ReLU can still lead to a gradual decrease in variance as signals propagate forward, potentially slowing down training or contributing to vanishing gradients in very deep networks.
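To see this effect concretely, here is a small simulation sketch (not part of the original argument; the width, depth, and batch size are arbitrary illustrative values) that pushes a random batch through a stack of Xavier-initialized linear layers, applies ReLU after each one, and prints the activation variance as depth grows:

```python
import torch
import torch.nn as nn

# Sketch: a stack of Xavier-initialized linear layers, each followed by ReLU.
# Width, depth, and batch size are arbitrary illustrative values.
torch.manual_seed(0)
width, depth = 1024, 20
x = torch.randn(256, width)

with torch.no_grad():
    for layer_idx in range(1, depth + 1):
        layer = nn.Linear(width, width, bias=False)
        nn.init.xavier_normal_(layer.weight)   # Var(W) = 2 / (n_in + n_out) = 1 / width
        x = torch.relu(layer(x))
        if layer_idx % 5 == 0:
            print(f"layer {layer_idx:2d}: activation variance = {x.var().item():.3e}")
```

Because each ReLU roughly halves the signal's second moment while Xavier scaling only preserves it for symmetric activations, the printed variance shrinks by roughly a factor of two per layer.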
Recognizing this mismatch, Kaiming He et al. proposed an initialization scheme specifically designed for ReLU and its variants in their paper "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification". The core idea is to explicitly account for the non-linearity introduced by ReLU when calculating the appropriate variance for the initial weights.
Let's consider a linear layer $y = Wx + b$. Assuming the inputs $x$ have zero mean and the weights $W$ are initialized independently with zero mean, the variance of the pre-activation output $y_i$ for a single neuron $i$ is given by:

$$\mathrm{Var}(y_i) = \sum_{j=1}^{n_{in}} \mathrm{Var}(W_{ij} x_j)$$

Assuming $W_{ij}$ and $x_j$ are independent, with $E[W_{ij}] = 0$ and $E[x_j] = 0$:

$$\mathrm{Var}(W_{ij} x_j) = E[W_{ij}^2 x_j^2] - \left(E[W_{ij} x_j]\right)^2 = E[W_{ij}^2]\,E[x_j^2] - \left(E[W_{ij}]\,E[x_j]\right)^2 = \mathrm{Var}(W_{ij})\,\mathrm{Var}(x_j)$$

So, summing over the $n_{in}$ inputs (the fan-in):

$$\mathrm{Var}(y_i) = n_{in}\,\mathrm{Var}(W)\,\mathrm{Var}(x)$$

Now, let $z = f(y)$ be the output after applying the ReLU activation $f$. The insight from He et al. concerns how ReLU affects this variance. Because $y$ is produced by a zero-mean linear layer, it is symmetrically distributed around zero, and ReLU sets its negative half to zero. (If $x$ itself comes from a previous ReLU layer it is no longer zero-mean and the full derivation tracks $E[x^2]$ instead of $\mathrm{Var}(x)$; here we focus on the forward pass through the current layer.) With $z_i = \max(0, y_i)$ and $y_i$ zero-mean and symmetric, $E[z_i^2] = \tfrac{1}{2} E[y_i^2]$, and since $E[y_i] = 0$ we have $E[y_i^2] = \mathrm{Var}(y_i)$. Therefore:

$$\mathrm{Var}(z_i) \approx E[z_i^2] = \tfrac{1}{2}\,\mathrm{Var}(y_i)$$

Substituting the expression for $\mathrm{Var}(y_i)$:

$$\mathrm{Var}(z_i) \approx \tfrac{1}{2}\, n_{in}\,\mathrm{Var}(W)\,\mathrm{Var}(x)$$

To maintain stable signal propagation, we want the variance of the activation output, $\mathrm{Var}(z_i)$, to be roughly equal to the variance of the layer's input, $\mathrm{Var}(x)$. This requires:

$$\mathrm{Var}(z_i) = \mathrm{Var}(x) \implies 1 \approx \tfrac{1}{2}\, n_{in}\,\mathrm{Var}(W)$$

Solving for the desired variance of the weights $W$:

$$\mathrm{Var}(W) = \frac{2}{n_{in}}$$

This is the fundamental result of Kaiming initialization for ReLU activations when considering the forward pass (fan-in mode). A similar derivation considering the backward pass (gradient flow) leads to $\mathrm{Var}(W) = \frac{2}{n_{out}}$.
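The two key steps can be checked numerically. The sketch below (a rough simulation under the derivation's assumptions; the sizes are arbitrary) first confirms that ReLU halves the second moment of a zero-mean symmetric signal, then shows that sampling $W$ with variance $2/n_{in}$ keeps the post-activation second moment at roughly the input level:

```python
import torch

torch.manual_seed(0)
n_in, n_out, batch = 2048, 2048, 4096   # arbitrary illustrative sizes

# Step 1: ReLU halves the second moment of a zero-mean, symmetric signal.
y = torch.randn(batch, n_in)            # E[y] = 0, symmetric around zero
z = torch.relu(y)
print(f"E[y^2] = {y.pow(2).mean():.4f}, E[relu(y)^2] = {z.pow(2).mean():.4f}")  # ~1.0 vs ~0.5

# Step 2: with Var(W) = 2 / n_in, a linear layer followed by ReLU keeps the
# second moment of the signal roughly constant (the quantity the derivation tracks).
x = torch.randn(batch, n_in)                         # E[x^2] ~ 1
W = torch.randn(n_out, n_in) * (2.0 / n_in) ** 0.5   # std = sqrt(2 / n_in)
z = torch.relu(x @ W.t())
print(f"E[x^2] = {x.pow(2).mean():.4f}, E[relu(Wx)^2] = {z.pow(2).mean():.4f}")  # both ~1.0
```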
Based on this derived variance, we can initialize weights using either a normal or a uniform distribution.
Kaiming Normal Initialization: Weights are sampled from a normal distribution $\mathcal{N}(0, \sigma^2)$, where the standard deviation is $\sigma = \sqrt{\frac{2}{n_{in}}}$.
Kaiming Uniform Initialization: Weights are sampled from a uniform distribution $U(-\text{bound}, \text{bound})$, where the bound is calculated to match the desired variance. Since $\mathrm{Var}\big(U(-b, b)\big) = b^2/3$, this gives $\text{bound} = \sqrt{\frac{6}{n_{in}}}$.
The 'fan-in' mode is generally preferred as it maintains variance during the forward pass.
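As a quick arithmetic check (using the same illustrative fan_in of 2048 as the PyTorch example below), these two formulas can be computed by hand and compared against the ReLU gain reported by torch.nn.init.calculate_gain:

```python
import math
import torch.nn as nn

# Manual check of the two formulas for an illustrative fan_in of 2048
# (the same value used in the PyTorch example below).
fan_in = 2048
gain = nn.init.calculate_gain('relu')    # sqrt(2) for ReLU
std = gain / math.sqrt(fan_in)           # sigma = sqrt(2 / fan_in)
bound = gain * math.sqrt(3.0 / fan_in)   # bound = sqrt(6 / fan_in)

print(f"gain          : {gain:.6f}")
print(f"normal std    : {std:.6f}")
print(f"uniform bound : {bound:.6f}")
# Both parameterizations target the same weight variance, 2 / fan_in:
print(f"std**2        : {std**2:.8f}")
print(f"bound**2 / 3  : {bound**2 / 3:.8f}")
print(f"2 / fan_in    : {2 / fan_in:.8f}")
```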
PyTorch provides convenient functions for Kaiming initialization within the torch.nn.init module.

```python
import torch
import torch.nn as nn
# Example linear layer typical in a Transformer FFN
fan_in = 2048   # example d_model
fan_out = 8192  # example feed-forward dimension (often 4 * d_model)
linear_layer = nn.Linear(fan_in, fan_out, bias=False)  # biases, when used, are typically zero-initialized (see below)

# --- Kaiming Normal Initialization (fan_in, for ReLU) ---
# 'a=0' is the default for ReLU; pass the negative slope instead for Leaky ReLU.
# mode='fan_in' preserves variance in the forward pass.
# nonlinearity='relu' selects the gain calculation appropriate for ReLU.
nn.init.kaiming_normal_(
    linear_layer.weight,
    mode='fan_in',
    nonlinearity='relu'
)
print("Kaiming Normal Initialized Weights (Shape, Sample):")
print(linear_layer.weight.data.shape)
print(linear_layer.weight.data[0, :5])
actual_var_normal = linear_layer.weight.data.var()
expected_var = 2.0 / fan_in
print(f"\nVariance (Normal): {actual_var_normal:.6f}")
print(f"Expected Variance (2 / fan_in): {expected_var:.6f}")
# --- Kaiming Uniform Initialization (fan_in, for ReLU) ---
linear_layer_uniform = nn.Linear(fan_in, fan_out, bias=False)
nn.init.kaiming_uniform_(
    linear_layer_uniform.weight,
    mode='fan_in',
    nonlinearity='relu'
)
print("\nKaiming Uniform Initialized Weights (Shape, Sample):")
print(linear_layer_uniform.weight.data.shape)
print(linear_layer_uniform.weight.data[0, :5])
actual_var_uniform = linear_layer_uniform.weight.data.var()
# Expected variance is still 2 / fan_in
print(f"\nVariance (Uniform): {actual_var_uniform:.6f}")
print(f"Expected Variance (2 / fan_in): {expected_var:.6f}")
# --- Initializing Biases ---
# Biases are typically initialized to zero
bias_tensor = torch.zeros(fan_out)
print("\nBias Initialization (Example):")
print(bias_tensor[:5])
```
In the code, nonlinearity='relu' tells the function to use the gain factor associated with ReLU, which is $\sqrt{2}$. This factor arises directly from the derivation requiring $\mathrm{Var}(W) = 2/n_{in}$. If you were using Leaky ReLU with a negative slope $a$, you would set nonlinearity='leaky_relu' and pass the slope via the a parameter, which adjusts the gain calculation accordingly. mode='fan_in' ensures the variance calculation uses the number of input features ($n_{in}$).
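As a short illustration of the Leaky ReLU case (the slope value 0.2 below is arbitrary), PyTorch computes the gain as $\sqrt{2/(1+a^2)}$, and passing a together with nonlinearity='leaky_relu' scales the weights accordingly:

```python
import math
import torch.nn as nn

# Hypothetical Leaky ReLU slope; the gain becomes sqrt(2 / (1 + a^2)).
a = 0.2
gain = nn.init.calculate_gain('leaky_relu', a)
print(f"gain for a={a}: {gain:.6f}  (formula: {math.sqrt(2.0 / (1 + a**2)):.6f})")

leaky_layer = nn.Linear(2048, 8192, bias=False)
nn.init.kaiming_normal_(
    leaky_layer.weight, a=a, mode='fan_in', nonlinearity='leaky_relu'
)
print(f"weight std : {leaky_layer.weight.data.std():.6f}")
print(f"expected   : {gain / math.sqrt(2048):.6f}")   # gain / sqrt(fan_in)
```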
Kaiming initialization is the standard choice for initializing the weight matrices within the position-wise feed-forward networks (FFNs) of a Transformer, as these typically use ReLU or close relatives like GELU or SwiGLU. While GELU and SwiGLU aren't exactly ReLU, Kaiming initialization often serves as a good starting point. For embedding layers and the linear projections within attention mechanisms, different strategies might be employed (often closer to Xavier normal or simply a scaled standard normal distribution), but for the core FFN layers activated by ReLU-like functions, Kaiming initialization is critical for enabling the training of deep stacks.
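As a sketch of how this might look in practice (the module below is illustrative, with hypothetical d_model and d_ff values, not a reference implementation), a position-wise FFN can apply Kaiming initialization to the ReLU-activated expansion and a different scheme, here Xavier, to the output projection:

```python
import torch
import torch.nn as nn

# Illustrative position-wise FFN; d_model and d_ff are hypothetical values.
class PositionwiseFFN(nn.Module):
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)    # followed by ReLU -> Kaiming init
        self.w2 = nn.Linear(d_ff, d_model)    # output projection, no ReLU after it
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_normal_(self.w1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.xavier_normal_(self.w2.weight)   # one common choice for the projection
        nn.init.zeros_(self.w1.bias)
        nn.init.zeros_(self.w2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))

ffn = PositionwiseFFN()
out = ffn(torch.randn(4, 16, 512))   # (batch, seq_len, d_model)
print(out.shape)                     # torch.Size([4, 16, 512])
```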
By specifically addressing the properties of ReLU activations, Kaiming initialization provides a principled way to set the initial weight scale for deep networks dominated by these units. It plays a significant role in preventing the signal variance from decaying rapidly, thereby facilitating the stable and efficient training of large models like modern Transformers.