Choosing the right optimization algorithm is a significant factor when training large-scale models, particularly in distributed environments. While standard optimizers like Stochastic Gradient Descent (SGD) with momentum or Adam form the basis, scaling up training often necessitates optimizers designed to handle very large batch sizes or specific regularization needs effectively. The techniques discussed earlier in this chapter, such as data parallelism with pmap and gradient accumulation, directly lead to larger effective batch sizes, influencing optimizer behavior.
Adam remains a widely used optimizer due to its adaptive learning rates. However, the standard implementation of L2 regularization in Adam is often suboptimal. AdamW modifies Adam by decoupling the weight decay calculation from the gradient update associated with the adaptive learning rates.
In standard Adam with L2 regularization, the decay term interacts with the adaptive moments (m_t and v_t), potentially leading to weights with large historical gradients decaying less than those with smaller gradients. AdamW applies the weight decay directly to the weights after the Adam step, behaving more like the weight decay used with SGD.
The conceptual update for a parameter θ looks like:
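\theta_{t+1} = \theta_t - \eta_t \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)

where \hat{m}_t and \hat{v}_t are the bias-corrected first and second moment estimates, \eta_t is the (possibly scheduled) learning rate, \epsilon is a small constant for numerical stability, and \lambda is the weight decay coefficient.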
This decoupling often results in better generalization performance and more stable training for large models like Transformers compared to the original Adam implementation with L2 regularization. It has become a standard choice for many large language model training recipes.
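As a rough sketch of this distinction in optax (the library used later in this section; the hyperparameter values here are illustrative), coupled L2 regularization can be emulated by adding the decay term to the gradients before the adaptive scaling, whereas optax.adamw applies the decay outside of it:

import optax

# Decoupled weight decay (AdamW): the decay is applied to the weights directly,
# outside the adaptive moment scaling.
adamw = optax.adamw(learning_rate=1e-3, weight_decay=0.01)

# Coupled L2 regularization: the decay term is added to the gradients first,
# so it is scaled by the adaptive moments just like any other gradient component.
adam_l2 = optax.chain(
    optax.add_decayed_weights(0.01),
    optax.adam(learning_rate=1e-3),
)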
When using massive batch sizes, often achieved through distributed training across many accelerators using pmap, standard optimizers like Adam can sometimes become unstable or require careful learning rate tuning (especially extensive warm-up). LAMB (Layer-wise Adaptive Moments optimizer for Batch training) was developed specifically to enable stable training with extremely large batch sizes (tens of thousands of examples or more).
The main idea behind LAMB is to apply layer-wise normalization to the parameter updates. It computes the Adam update step similarly to AdamW but then normalizes the update for each layer based on the ratio of the norms of the weights and the Adam update for that layer.
Conceptually, for each layer l:
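u_t^{(l)} = \frac{\hat{m}_t^{(l)}}{\sqrt{\hat{v}_t^{(l)}} + \epsilon} + \lambda \theta_t^{(l)},
\qquad
\theta_{t+1}^{(l)} = \theta_t^{(l)} - \eta_t \, \frac{\lVert \theta_t^{(l)} \rVert}{\lVert u_t^{(l)} \rVert} \, u_t^{(l)}

where u_t^{(l)} is the AdamW-style update for layer l and the norm ratio is the layer-wise trust ratio (with safeguards when either norm is zero; the full LAMB formulation also allows a clipping or scaling function on the weight norm).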
This layer-specific trust ratio scaling helps prevent updates from becoming excessively large or small for certain layers, which can happen with large batches where gradient variance is reduced, potentially leading Adam to take overly aggressive steps. LAMB has been shown effective in training models like BERT with very large batch sizes significantly faster than previously possible.
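A simplified sketch of the trust-ratio scaling for a single layer, with that layer's Adam-style update taken as given, might look like the following; optax provides a complete implementation as optax.lamb:

import jax.numpy as jnp

def trust_ratio_scale(layer_weights, layer_update, eps=1e-6):
    # Norms of this layer's weights and of its proposed Adam-style update.
    w_norm = jnp.linalg.norm(layer_weights)
    u_norm = jnp.linalg.norm(layer_update)
    # Trust ratio ||w|| / ||u||, falling back to 1.0 when either norm is zero.
    ratio = jnp.where((w_norm > 0) & (u_norm > 0), w_norm / (u_norm + eps), 1.0)
    return ratio * layer_update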
Beyond the choice of the core algorithm (AdamW, LAMB, etc.), other factors are essential for successful large-scale optimization:
Learning Rate Scheduling: Almost all large model training relies heavily on carefully designed learning rate schedules. Common strategies include a linear warm-up from a small initial value over the first few thousand steps, followed by cosine or linear decay toward a small final value; inverse square-root decay is also widely used for Transformer training. A sketch of composing such a schedule with optax appears after this list.
Optimizer State Management: Adaptive optimizers like Adam(W) and LAMB maintain state (e.g., momentum and variance estimates) for each parameter. For large models, this state can consume a significant amount of memory, sometimes rivaling the model parameters themselves. In distributed settings using pmap, this optimizer state must live on the devices alongside the parameters and gradients, replicated for plain data parallelism or sharded when more advanced partitioning schemes are used. Because the optimizer state is just another PyTree, libraries like optax (a popular gradient processing and optimization library for JAX) handle it naturally when used correctly within a pmap-decorated function. Ensure your training setup places the optimizer state on devices consistently with the parameters; a sketch follows this list.
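As a sketch of the warm-up-plus-decay pattern mentioned above (values are illustrative), an optax schedule is simply a function from step count to learning rate, so the pieces can be composed and inspected directly:

import optax

# Illustrative values: 2,000 warm-up steps into a cosine decay over the rest.
warmup = optax.linear_schedule(init_value=0.0, end_value=3e-4, transition_steps=2_000)
cosine = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=98_000)
schedule = optax.join_schedules([warmup, cosine], boundaries=[2_000])

# A schedule is a callable: learning rate at the start, end of warm-up, and late in training.
print(schedule(0), schedule(2_000), schedule(80_000))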
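And as a minimal sketch of optimizer state placement for data-parallel pmap training (using a toy, purely illustrative parameter PyTree), the state produced by optimizer.init can be replicated across local devices just like the parameters:

import jax
import jax.numpy as jnp
import optax

# Toy parameter PyTree standing in for real model parameters.
params = {'w': jnp.ones((512, 512)), 'b': jnp.zeros(512)}
optimizer = optax.adamw(learning_rate=1e-3, weight_decay=0.01)
opt_state = optimizer.init(params)  # the state PyTree mirrors the params PyTree

# Replicate parameters and optimizer state onto every local device so that a
# pmap'd training step can consume them; each leaf gains a leading device axis.
devices = jax.local_devices()
replicated_params = jax.device_put_replicated(params, devices)
replicated_opt_state = jax.device_put_replicated(opt_state, devices)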
Libraries like optax provide implementations of many common and advanced optimizers, integrating smoothly with JAX's functional programming model and transformations. Using optax typically involves defining a gradient transformation (for example, optax.adamw or optax.lamb), initializing its state with optimizer.init(params), transforming gradients into parameter updates with optimizer.update, and applying those updates with optax.apply_updates.
Here's a conceptual example of an update step using optax:
import jax
import jax.numpy as jnp
import optax

# Assume 'params' are model parameters (e.g., a PyTree)
# Assume 'grads' are gradients computed by jax.grad
# Assume 'opt_state' is the current optimizer state

# Define the optimizer (e.g., AdamW with a warm-up plus cosine decay schedule)
learning_rate_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4, warmup_steps=2_000,
    decay_steps=100_000, end_value=3e-5)  # illustrative values
optimizer = optax.adamw(learning_rate=learning_rate_schedule, weight_decay=0.01)

# Initialize optimizer state (typically done once outside the training loop)
# opt_state = optimizer.init(params)

# Inside the training step (potentially pmap'd)
@jax.jit
def update_step(params, grads, opt_state):
    # Transform raw gradients into parameter updates and advance the optimizer state.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # Apply the (already scaled and negated) updates to the parameters.
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Apply the update
# params, opt_state = update_step(params, grads, opt_state)
When this update_step function is used within a pmap, the same functional pattern applies on every device: gradients are aggregated across devices (e.g., using lax.pmean) before the optimizer update is computed, and optax then performs the update and state management identically on each replica.
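A minimal sketch of such a pmap-decorated step, assuming a hypothetical loss_fn(params, batch) that returns a scalar loss and the optimizer defined above, might look like this:

import functools
import jax
import optax

# Sketch only: 'optimizer' comes from the example above; loss_fn is hypothetical.
@functools.partial(jax.pmap, axis_name='devices')
def pmapped_update_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Average gradients across all devices before the optimizer update.
    grads = jax.lax.pmean(grads, axis_name='devices')
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

# Hypothetical usage with replicated inputs (each carrying a leading device axis):
# params, opt_state = pmapped_update_step(replicated_params, replicated_opt_state, device_batches)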
Choosing and tuning the optimizer, along with its associated learning rate schedule and weight decay, is an iterative process. While optimizers like AdamW provide a strong starting point, and LAMB offers advantages for extremely large batches, experimentation is often needed to find the best combination for a specific large-scale training task, considering the model architecture, dataset characteristics, and available computational resources.