While optimizers like AdamW and techniques like learning rate scheduling significantly aid Transformer training, the sheer depth of these models and the nature of their computations can sometimes lead to another numerical challenge: exploding gradients. Unlike the vanishing gradients often discussed in the context of recurrent networks, exploding gradients refer to the situation where the magnitude of gradients becomes excessively large during backpropagation. These large gradients can cause correspondingly large updates to the model weights during the optimizer step, potentially leading to numerical overflow, unstable training dynamics (e.g., oscillating or diverging loss), or even the destruction of previously learned information.
Exploding gradients can arise in deep networks because gradients are multiplied layer by layer during backpropagation. If the norms of these layer-wise factors are consistently greater than 1, their product can grow exponentially with network depth. While components like Layer Normalization help mitigate this, sudden large gradients can still occur, perhaps due to specific input sequences or interactions within the attention mechanism or feed-forward networks, particularly early in training before the model parameters have settled.
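To see this compounding effect concretely, here is a minimal sketch using a toy, unnormalized stack of linear layers; the depth and initialization scale are illustrative choices, deliberately set so that each layer's Jacobian norm exceeds 1:

import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative setup: a deep stack of unnormalized linear layers whose
# weights are scaled so each layer amplifies the backpropagated gradient
depth = 50
layers = nn.Sequential(*[nn.Linear(64, 64, bias=False) for _ in range(depth)])
for layer in layers:
    nn.init.normal_(layer.weight, std=0.2)

x = torch.randn(1, 64, requires_grad=True)
layers(x).sum().backward()

# The gradient norm at the input grows roughly exponentially with depth
print(f"gradient norm at the input: {x.grad.norm().item():.3e}")

With normalization layers and careful initialization this growth is largely suppressed, which is why explosions in practice tend to appear as occasional spikes rather than constant behavior.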
Gradient clipping is a direct technique employed to counteract this issue. It operates by imposing an upper limit on the magnitude (norm) of the gradients before the optimizer updates the model weights. If the overall norm of the gradients across all model parameters exceeds a predefined threshold, the gradients are rescaled downwards to match that threshold. This prevents single large gradient events from destabilizing the training process.
The most common approach is clipping by global norm. This involves calculating the L2 norm (Euclidean norm) of the gradient vector formed by concatenating the gradients of all trainable parameters in the model. Let $G$ represent this global gradient vector containing the gradients $\frac{\partial L}{\partial \theta_i}$ for all parameters $\theta_i$. The L2 norm is calculated as:

$$\text{global\_norm} = \sqrt{\sum_i \left( \frac{\partial L}{\partial \theta_i} \right)^2}$$

We then compare this $\text{global\_norm}$ to a predefined $\text{max\_norm}$ threshold. If $\text{global\_norm}$ exceeds $\text{max\_norm}$, the entire gradient vector $G$ is rescaled:

$$G_{\text{clipped}} = G \cdot \frac{\text{max\_norm}}{\text{global\_norm}}$$

If $\text{global\_norm}$ is less than or equal to $\text{max\_norm}$, the gradients remain unchanged: $G_{\text{clipped}} = G$.
An important property of this method is that it scales all gradients uniformly. This means the direction of the overall gradient vector in the parameter space is preserved; only its magnitude is limited. This helps prevent overly large steps without drastically altering the optimization path indicated by the gradient.
Flow diagram illustrating the gradient clipping process by global norm.
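To tie the formulas to working code, here is a minimal sketch that applies the rescaling rule to plain tensors standing in for per-parameter gradients and confirms that the direction is unchanged. The helper name clip_by_global_norm_ is an illustrative choice, not a library function:

import torch

def clip_by_global_norm_(grads, max_norm):
    # Compute the global L2 norm over all gradient tensors (formula above)
    global_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    # Rescale in place only when the norm exceeds the threshold
    if global_norm > max_norm:
        for g in grads:
            g.mul_(max_norm / global_norm)
    return global_norm

# Toy gradients whose global norm is sqrt(3*16 + 4*9) = sqrt(84) ≈ 9.17
grads = [torch.full((3,), 4.0), torch.full((4,), 3.0)]
before = torch.cat([g.flatten() for g in grads])

pre_clip_norm = clip_by_global_norm_(grads, max_norm=1.0)
after = torch.cat([g.flatten() for g in grads])

print(f"norm before: {pre_clip_norm.item():.3f}, after: {after.norm().item():.3f}")
# Uniform rescaling preserves direction: cosine similarity stays at 1
cos = torch.dot(before, after) / (before.norm() * after.norm())
print(f"cosine similarity: {cos.item():.6f}")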
Deep learning frameworks provide built-in functions for gradient clipping, making it easy to integrate into the training loop. It is typically applied after the loss.backward() call (which computes the gradients) and before the optimizer.step() call (which updates the weights based on the gradients).
import torch
import torch.nn as nn

# PyTorch example; a toy model and random data stand in for a real setup
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs, targets = torch.randn(8, 16), torch.randn(8, 1)

# Forward pass
outputs = model(inputs)
loss = nn.functional.mse_loss(outputs, targets)

# Backward pass to compute gradients
loss.backward()

# Gradient clipping (applied *after* backward, *before* step)
gradient_threshold = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_threshold)

# Optimizer step (uses the potentially clipped gradients)
optimizer.step()

# Zero gradients for the next iteration
optimizer.zero_grad()
In TensorFlow, the equivalent typically involves tf.clip_by_global_norm within a custom training loop, or clipping handled implicitly by higher-level APIs when specified.
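Here is a minimal sketch of how this might look in a custom TF 2.x training step; the toy model and data are illustrative stand-ins:

import tensorflow as tf

# Illustrative setup: a tiny Dense model and random data
model = tf.keras.layers.Dense(1)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
inputs, targets = tf.random.normal((8, 16)), tf.random.normal((8, 1))

with tf.GradientTape() as tape:
    outputs = model(inputs, training=True)
    loss = loss_fn(targets, outputs)

grads = tape.gradient(loss, model.trainable_variables)

# tf.clip_by_global_norm returns the clipped gradients and the pre-clip norm
clipped_grads, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(clipped_grads, model.trainable_variables))

In recent TF 2.x releases, Keras optimizers also accept a global_clipnorm argument that applies the same rescaling automatically, avoiding the manual step.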
The max_norm value is a hyperparameter that usually requires some tuning. Common values often fall in the range [0.5, 5.0], with 1.0 being a frequent starting point.
Monitoring the actual gradient norms during training (if your framework or monitoring tools allow it) can help inform the choice of threshold. If you observe frequent clipping events or very large pre-clip norms, that confirms the need for clipping and helps you adjust the threshold. Otherwise, observing the training loss curve for stability is the primary empirical guide.
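Conveniently, PyTorch's clip_grad_norm_ returns the total gradient norm measured before clipping, so logging it costs nothing extra. Below is a minimal sketch of this monitoring pattern; the toy model, data, and print-based logging are illustrative choices:

import torch
import torch.nn as nn

# Toy setup for illustration
model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_norm = 1.0

for step in range(100):
    inputs, targets = torch.randn(8, 16), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    # clip_grad_norm_ returns the pre-clip global norm
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if total_norm > max_norm:
        print(f"step {step}: clipped (pre-clip norm {total_norm.item():.2f})")

    optimizer.step()
    optimizer.zero_grad()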
Gradient clipping is a standard practice in training large Transformer models. It acts as a safeguard, complementing other techniques like careful initialization, normalization layers, and learning rate schedules to promote stable and effective convergence, especially for deep architectures where gradient dynamics can be complex.