While the mathematical definitions for operations like matrix multiplication or differentiation are precise, their implementation on computers involves finite-precision floating-point numbers (like 32-bit float or 16-bit half). This finite precision can lead to subtle, and sometimes not-so-subtle, problems during the training of deep neural networks, particularly the very deep architectures common in large language models. These numerical stability issues can drastically impede or even halt the training process if not properly addressed.
During backpropagation, gradients are calculated using the chain rule, propagating from the output layer back through the network. Each step involves multiplying by the local gradient of that layer's operation (including the activation function's derivative) and the layer's weights. Consider a deep network with many layers. If the magnitudes of these gradients (particularly the activation function derivatives or weight matrices) are consistently less than 1, the gradient signal can shrink exponentially as it propagates backward.
$$\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial \text{out}_N} \cdots \frac{\partial \text{out}_3}{\partial \text{in}_3} \cdot \frac{\partial \text{in}_3}{\partial \text{out}_2} \cdot \frac{\partial \text{out}_2}{\partial \text{in}_2} \cdot \frac{\partial \text{in}_2}{\partial W_1}$$
If many terms in this chain product have magnitudes less than 1, the resulting gradient $\frac{\partial L}{\partial W_1}$ for early layers (like $W_1$) can become extremely small, approaching zero.
This phenomenon is known as the vanishing gradient problem. When gradients vanish, the weights in the earlier layers of the network receive virtually no updates, and the network effectively stops learning meaningful representations from the data in those layers. This was a significant obstacle in training early deep networks, especially those using activation functions like sigmoid or tanh, whose derivatives saturate (approach zero) for large positive or negative inputs.
import torch
import matplotlib.pyplot as plt
# Example of Sigmoid derivative saturation
x = torch.linspace(-10, 10, 200)
sigmoid_x = torch.sigmoid(x)
sigmoid_grad = sigmoid_x * (1 - sigmoid_x) # Derivative of sigmoid
# Plotting
fig, ax1 = plt.subplots()
color = '#4263eb' # blue
ax1.set_xlabel('Input Value (x)')
ax1.set_ylabel('Sigmoid Activation', color=color)
ax1.plot(x, sigmoid_x, color=color, label='Sigmoid(x)')
ax1.tick_params(axis='y', labelcolor=color)
ax1.grid(True, linestyle=':')
ax2 = ax1.twinx() # instantiate a second axes that shares the same x-axis
color = '#f03e3e' # red
ax2.set_ylabel('Sigmoid Derivative', color=color)
ax2.plot(
x,
sigmoid_grad,
color=color,
linestyle='--',
label='d(Sigmoid)/dx'
)
ax2.tick_params(axis='y', labelcolor=color)
ax2.set_ylim(bottom=0) # Derivative is non-negative
fig.tight_layout() # otherwise the right y-label is slightly clipped
plt.title("Sigmoid Activation and its Derivative")
plt.show()
The derivative of the sigmoid function is small (max 0.25) and approaches zero for large positive or negative inputs. Multiplying many small numbers during backpropagation leads to vanishing gradients.
Conversely, if the terms in the chain rule product (gradients, weights) are consistently greater than 1 in magnitude, the gradient signal can grow exponentially as it propagates backward. This leads to the exploding gradient problem.
Exploding gradients result in excessively large updates to the model weights ($\theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta)$). These large updates can cause the optimization process to become unstable, oscillating wildly or diverging completely. In practice, this often manifests as the loss function suddenly shooting up to NaN (Not a Number) or Inf (Infinity) during training, as the numerical values exceed the representable range of the floating-point type. This is particularly problematic with recurrent connections or very deep networks where the same weight matrices might be multiplied many times.
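To make this concrete, here is a small toy sketch (illustrative only; the dimension, depth, and scale factors are arbitrary values chosen for the demonstration) that repeatedly multiplies a vector by the same random weight matrix. A per-step scale slightly above 1 makes the norm grow exponentially, while a scale slightly below 1 drives it toward zero.
import torch
torch.manual_seed(0)
dim, depth = 64, 50
g = torch.randn(dim)  # stand-in for a gradient signal
for scale, label in [(1.2, "exploding"), (0.8, "vanishing")]:
    # Random matrix whose typical effect is to multiply vector norms by roughly `scale`
    W = torch.randn(dim, dim) * scale / dim**0.5
    v = g.clone()
    for _ in range(depth):
        v = W @ v
    print(f"{label}: norm after {depth} multiplications = {v.norm().item():.3e}")
The same compounding happens in reverse during backpropagation, which is why depth alone can push gradient magnitudes to either extreme.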
Standard deep learning typically uses 32-bit floating-point numbers (FP32 or float). However, training large language models often leverages lower-precision formats like 16-bit floating-point (FP16 or half) or BFloat16 (BF16) to reduce memory consumption and accelerate computation, especially on hardware with specialized units like NVIDIA's Tensor Cores.
These lower-precision formats carry less precision than FP32, and FP16 in particular has a much smaller representable range (BF16 keeps FP32's exponent range but with fewer bits of precision). This makes them far more susceptible to overflow (values becoming Inf) or underflow (values becoming zero). Small gradients, common in the vanishing gradient scenario, can easily become zero in FP16, halting learning. Large gradients or intermediate activation values can exceed the maximum representable value, causing Inf or NaN. Using these formats requires careful handling to maintain numerical stability, which we will explore further in Chapter 20.
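As a quick illustration (a minimal sketch using standard PyTorch dtype casts), the snippet below shows a large value overflowing to Inf and a tiny value underflowing to zero in FP16, while BF16, which shares FP32's exponent range, keeps both at reduced precision.
import torch
x = torch.tensor([1e5, 1e-8], dtype=torch.float32)
print(x.half())      # FP16: 1e5 overflows to inf, 1e-8 underflows to 0
print(x.bfloat16())  # BF16: both values survive, though with fewer bits of precision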
Fortunately, several techniques have been developed to combat these stability issues, forming a standard toolkit for training deep models:
Careful Initialization: Initializing weights appropriately helps prevent gradients from vanishing or exploding right from the start. Techniques like Xavier/Glorot and Kaiming initialization (Chapter 12) set initial weight scales based on layer dimensions.
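As a brief sketch (illustrative, using PyTorch's built-in initializers), Kaiming initialization picks the weight scale so that the signal's magnitude is roughly preserved through a ReLU layer:
import torch
import torch.nn as nn
layer = nn.Linear(1024, 1024)
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')  # std = sqrt(2 / fan_in)
nn.init.zeros_(layer.bias)
x = torch.randn(32, 1024)
out = torch.relu(layer(x))
# Mean squared activation stays near 1.0, so signals neither shrink nor grow layer by layer
print(f"input mean square: {x.pow(2).mean():.3f}, output mean square: {out.pow(2).mean():.3f}")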
Normalization Layers: Layers like Batch Normalization or, more commonly in Transformers, Layer Normalization (Chapter 4) rescale activations within a layer to have zero mean and unit variance. This helps keep activations and gradients within reasonable ranges, stabilizing training.
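For instance (a minimal sketch with PyTorch's nn.LayerNorm, using arbitrary example activations), layer normalization maps each row's features to roughly zero mean and unit variance regardless of how large the incoming values are:
import torch
import torch.nn as nn
h = torch.randn(4, 512) * 50 + 10   # activations with a large, shifted scale
ln = nn.LayerNorm(512)
out = ln(h)
print(f"before: mean={h.mean():.1f}, std={h.std():.1f}")
print(f"after:  mean={out.mean():.2f}, std={out.std():.2f}")  # approximately 0 and 1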
Gradient Clipping: This technique directly addresses exploding gradients by capping the maximum magnitude or norm of the gradients before the weight update step. If the gradient norm exceeds a threshold, it is rescaled downwards (Chapter 17).
import torch
import torch.nn as nn

# Example parameters and gradients (replace with your model's)
param1 = torch.randn(100, 100, requires_grad=True)
param2 = torch.randn(50, 100, requires_grad=True)
parameters = [param1, param2]

# Simulate gradients (e.g., after loss.backward())
if param1.grad is None:  # Create dummy grads if none exist
    param1.grad = torch.randn_like(param1) * 100
if param2.grad is None:
    param2.grad = torch.randn_like(param2) * 50

# Calculate total gradient norm
total_norm = 0.0
for p in parameters:
    if p.grad is not None:
        param_norm = p.grad.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
print(f"Original Gradient Norm: {total_norm:.2f}")

# Apply gradient clipping (using PyTorch utility)
max_norm = 1.0
nn.utils.clip_grad_norm_(parameters, max_norm)

# Calculate norm after clipping
clipped_total_norm = 0.0
for p in parameters:
    if p.grad is not None:
        param_norm = p.grad.norm(2)
        clipped_total_norm += param_norm.item() ** 2
clipped_total_norm = clipped_total_norm ** 0.5
print(f"Clipped Gradient Norm: {clipped_total_norm:.2f}")

# Expected output will show the original norm likely > 1.0
# and the clipped norm very close to 1.0
Activation Functions: Using non-saturating activation functions like ReLU (Rectified Linear Unit) or its variants (GeLU, SwiGLU) helps mitigate the vanishing gradient problem compared to sigmoid or tanh (Chapter 11).
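A quick check (illustrative) of the derivatives at a moderately large input shows why: the sigmoid gradient is nearly zero there, while ReLU's is exactly 1 for any positive input.
import torch
x = torch.tensor(5.0, requires_grad=True)
torch.sigmoid(x).backward()
print(f"sigmoid'(5) = {x.grad.item():.4f}")  # ~0.0066, nearly vanished
x = torch.tensor(5.0, requires_grad=True)
torch.relu(x).backward()
print(f"relu'(5) = {x.grad.item():.1f}")     # 1.0, the gradient passes through unchanged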
Mixed Precision Training Techniques: Methods like loss scaling are used specifically with FP16 to dynamically scale the loss function, effectively scaling up the gradients during backpropagation to prevent underflow, before scaling the gradients back down before the weight update (Chapter 20).
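A minimal sketch of this pattern using PyTorch's AMP utilities follows (it assumes a CUDA-capable device and a toy linear model; the full treatment is in Chapter 20):
import torch
import torch.nn as nn
device = "cuda"
model = nn.Linear(512, 512).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler()
for _ in range(3):
    x = torch.randn(16, 512, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # forward pass runs in reduced precision where safe
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()     # backward on the scaled loss so small grads don't underflow
    scaler.step(optimizer)            # unscales gradients; skips the step if inf/NaN is found
    scaler.update()                   # adjusts the scale factor for the next iteration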
Understanding these potential numerical pitfalls and the strategies to address them is fundamental when working with the scale and depth characteristic of modern large language models. Without careful consideration of numerical stability, training these powerful models would be practically infeasible.