Selecting the right hyperparameters for your optimizer is fundamental to successfully training large language models. While adaptive optimizers like AdamW simplify some aspects compared to vanilla SGD, finding the optimal settings for learning rate, momentum terms, epsilon, and weight decay remains a significant part of the training process, especially given the scale and cost associated with LLM training. These hyperparameters directly influence convergence speed, stability, and the final performance of the model. Building on our discussion of AdamW, learning rate schedules, and gradient clipping, let's examine how to approach choosing values for these critical parameters.
Learning Rate (η)
The learning rate is arguably the single most important hyperparameter. It dictates the step size taken during gradient descent. Set it too high, and training can become unstable, leading to divergence or oscillating loss. Set it too low, and training will be impractically slow, potentially getting stuck in suboptimal local minima.
Typical Range for LLMs: For large Transformer models trained with AdamW, peak learning rates often fall within the range of 1e−5 to 6e−4. The specific value depends heavily on the model size, batch size, and architecture. Larger models sometimes benefit from slightly smaller peak learning rates. For instance, a common starting point for models in the multi-billion parameter range might be around 1e−4 to 3e−4.
Interaction with Batch Size: Training LLMs typically involves very large batch sizes (often millions of tokens per batch, achieved through data parallelism and gradient accumulation). A common heuristic is the linear scaling rule: if you multiply the batch size by k, you should also multiply the learning rate by k. However, this rule doesn't always hold perfectly, especially for extremely large batch sizes. Some research suggests a square root scaling rule (η ∝ √k) may be more appropriate beyond a certain batch size. In practice, while batch size influences the choice, the optimal learning rate is usually found empirically within the typical ranges mentioned, often starting from values reported in papers training similarly sized models.
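To make these heuristics concrete, here is a minimal sketch that computes both scalings from a hypothetical reference run; the batch sizes and learning rate are illustrative placeholders, not recommendations:

import math

# Hypothetical reference configuration (illustrative values only)
reference_batch_tokens = 1_000_000   # tokens per batch in the reference run
reference_lr = 1e-4                  # peak learning rate used in the reference run

new_batch_tokens = 4_000_000         # the new run uses a 4x larger batch
k = new_batch_tokens / reference_batch_tokens

linear_scaled_lr = reference_lr * k            # linear rule: multiply by k   -> 4e-4
sqrt_scaled_lr = reference_lr * math.sqrt(k)   # square-root rule: multiply by sqrt(k) -> 2e-4

print(f"linear rule: {linear_scaled_lr:.1e}, square-root rule: {sqrt_scaled_lr:.1e}")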
Interaction with Schedules: Remember that we typically use a learning rate schedule with warmup and decay (Chapter 17, "Learning Rate Scheduling Strategies"). When we talk about "the learning rate", we usually mean the peak learning rate achieved after the warmup phase. The warmup duration and decay strategy (e.g., linear or cosine) also interact with the peak learning rate's effectiveness. A longer warmup might allow for a slightly higher peak learning rate.
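As a rough illustration of how these pieces fit together, the sketch below computes the learning rate at a given step for a linear warmup followed by cosine decay; the step counts and peak rate are placeholder values:

import math

def lr_at_step(step, peak_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at a given step: linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear ramp from near zero up to the peak learning rate
        return peak_lr * (step + 1) / warmup_steps
    # Cosine decay from peak_lr down to min_lr over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Illustrative placeholders: 2,000 warmup steps out of 100,000 total, peak LR 2e-4
for s in (0, 1_000, 2_000, 50_000, 100_000):
    print(f"step {s:>6}: lr = {lr_at_step(s, 2e-4, 2_000, 100_000):.2e}")

A longer warmup simply stretches the first branch of this function, which is part of why it can make a somewhat higher peak value tolerable.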
Finding the Optimal Value: Given the computational cost, exhaustive grid searches are often infeasible. A common strategy is to:
Start with values reported in the literature for models of comparable size and architecture (e.g., Llama, GPT-3 papers).
Perform a small sweep around this value (e.g., 1e−4, 2e−4, 3e−4) on a shorter training run or a smaller-scale version of the model/dataset if possible.
Monitor training loss curves closely during the initial phases. Rapid drops followed by instability suggest the learning rate is too high. Extremely slow progress suggests it's too low.
Adam/AdamW Betas (β1, β2)
Adam and AdamW use two momentum-like terms, controlled by β1 and β2.
β1: Controls the exponential moving average of the gradients (first moment). A typical default is 0.9.
β2: Controls the exponential moving average of the squared gradients (second moment). A typical default is 0.999. This term adapts the learning rate per parameter.
For most LLM training scenarios, the default values of β1=0.9 and β2=0.999 work remarkably well and are often used without modification.
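To make the roles of β1 and β2 (and the ϵ discussed below) concrete, here is a minimal single-tensor sketch of the decoupled AdamW update; it follows the standard published update equations rather than any particular library's internals:

import torch

def adamw_step(param, grad, m, v, step, lr=2e-4,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.1):
    """One AdamW update for a single tensor (illustrative, not a library implementation)."""
    # First moment: exponential moving average of gradients, controlled by beta1
    m = beta1 * m + (1 - beta1) * grad
    # Second moment: exponential moving average of squared gradients, controlled by beta2
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for initializing m and v at zero
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    # Decoupled weight decay: shrink the weights directly, outside the adaptive term
    param = param - lr * weight_decay * param
    # Adaptive step: eps keeps the denominator away from zero
    param = param - lr * m_hat / (v_hat.sqrt() + eps)
    return param, m, v

param, m, v = torch.ones(3), torch.zeros(3), torch.zeros(3)
grad = torch.full((3,), 0.5)
param, m, v = adamw_step(param, grad, m, v, step=1)
print(param)

The closer β2 is to 1, the longer the memory of the squared-gradient average; lowering it makes the per-parameter scaling react faster to changes in gradient variance, which is the trade-off discussed next.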
Tuning Considerations: While tuning these is less common than tuning the learning rate, some studies have explored alternatives.
Lowering β2 (e.g., to 0.95 or 0.98) makes the optimizer adapt more quickly to changes in the gradient variance, which can sometimes help escape sharp minima or improve stability early in training, but may also lead to less stable convergence later. Some large model training recipes use β2=0.95 or β2=0.98.
Adjusting β1 is less frequent.
It's generally recommended to stick with the defaults (β1=0.9, β2=0.999), or at most try β2=0.95 or 0.98, unless you observe specific instabilities that aren't resolved by tuning the learning rate or using gradient clipping. Changing the betas adds further complexity to the tuning process.
Epsilon (ϵ)
The epsilon term (ϵ) in Adam/AdamW is a small value added to the denominator during the adaptive learning rate calculation (specifically, before taking the square root of the second moment estimate). Its primary purpose is to prevent division by zero and improve numerical stability, particularly when the second moment estimate is very close to zero.
Common Value: The standard default value is 1e−8.
Tuning Considerations: Epsilon is rarely tuned. It generally has a minimal impact on performance compared to the learning rate or weight decay. In scenarios involving mixed-precision training (Chapter 20), especially FP16, where numerical precision is lower, some practitioners might slightly increase epsilon (e.g., to 1e−7 or 1e−6) to further enhance numerical stability. However, this is typically a last resort if stability issues related to the optimizer denominator are suspected. Sticking with the default 1e−8 is standard practice.
Weight Decay (λ)
Weight decay is a regularization technique used to discourage large weights and reduce overfitting. Classically it is implemented by adding a penalty proportional to the squared magnitude of the model weights to the loss function (L2 regularization). As discussed previously, AdamW instead implements decoupled weight decay, shrinking the weights directly at each update step, which is generally preferred over the L2 regularization approach inherent in the original Adam optimizer when used with adaptive gradients.
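The difference is easiest to see side by side. In the minimal sketch below, denom is a stand-in for the adaptive term (the square root of the second moment estimate plus ϵ) that Adam divides by:

import torch

lr, wd = 2e-4, 0.1
param = torch.ones(3)
grad = torch.full((3,), 0.5)
denom = torch.full((3,), 0.5)  # stand-in for sqrt(v_hat) + eps

# Original Adam with L2 regularization: the decay term is folded into the gradient,
# so it is divided by the adaptive denominator (parameters with a large gradient
# history are effectively regularized less).
l2_update = param - lr * (grad + wd * param) / denom

# AdamW (decoupled): the decay shrinks the weights directly and is unaffected
# by the adaptive denominator.
adamw_update = param - lr * grad / denom - lr * wd * param

print(l2_update, adamw_update)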
Typical Range for LLMs: Common values for weight decay (λ) in LLM training often range from 0.01 to 0.1. A value of 0.1 is frequently used as a starting point.
Tuning Strategy: The optimal weight decay value is data and model-dependent.
It's usually tuned after finding a reasonable learning rate and schedule.
Monitor the validation loss (or perplexity). If the validation loss starts increasing while the training loss continues decreasing, the model might be overfitting, and increasing weight decay could help. Conversely, if both training and validation loss plateau too early, reducing weight decay might be beneficial.
A small search (e.g., over values like 0.01, 0.05, 0.1) is often sufficient.
Note that for very large models trained on massive datasets, the implicit regularization from the data scale might reduce the need for strong explicit weight decay, but a value around 0.1 is still common.
Practical Implementation and Interactions
Remember that these hyperparameters interact. Changing the learning rate might necessitate adjusting the weight decay. The effectiveness of a learning rate schedule is tied to the peak learning rate chosen.
Here's how you might define an AdamW optimizer in PyTorch, specifying these hyperparameters:
import torch
# Assume 'model' is your large language model instance
# Define hyperparameters
peak_lr = 2e-4
beta1 = 0.9
beta2 = 0.98 # Example using a slightly adjusted beta2
epsilon = 1e-8
weight_decay_lambda = 0.1
# Filter parameters that should not have weight decay applied
# (e.g., biases, LayerNorm weights)
decay_params = []
no_decay_params = []
for pn, p in model.named_parameters():
    if p.requires_grad:
        # Check for biases or 1D parameters (like LayerNorm weights)
        if pn.endswith("bias") or len(p.shape) == 1:
            no_decay_params.append(p)
            # print(f"No decay for: {pn}")  # Uncomment for debugging
        else:
            decay_params.append(p)
            # print(f"Decay for: {pn}")  # Uncomment for debugging

# Create optimizer parameter groups
optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': weight_decay_lambda},
    {'params': no_decay_params, 'weight_decay': 0.0},  # No weight decay for these parameters
]

# Instantiate the optimizer
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=peak_lr,  # Note: the LR scheduler will adjust this value during training
    betas=(beta1, beta2),
    eps=epsilon,
)

print(f"Optimizer created with {len(optimizer_grouped_parameters)} parameter groups.")
print(
    f"Group 0 (decay): {len(optimizer_grouped_parameters[0]['params'])} tensors, "
    f"weight_decay={optimizer_grouped_parameters[0]['weight_decay']}"
)
print(
    f"Group 1 (no_decay): {len(optimizer_grouped_parameters[1]['params'])} tensors, "
    f"weight_decay={optimizer_grouped_parameters[1]['weight_decay']}"
)
# Example of how a scheduler would typically be used (setup not shown here).
# A PyTorch LR scheduler built from peak_lr, warmup, and decay updates the
# optimizer's param_groups itself, so no manual lr assignment is needed:
# for step in range(num_training_steps):
#     # ... forward pass, loss calculation, backward pass ...
#     optimizer.step()
#     scheduler.step()      # advance the schedule and update the optimizer's lr
#     optimizer.zero_grad()
Example PyTorch code snippet demonstrating how to initialize the AdamW optimizer with separate parameter groups for applying weight decay selectively. Biases and normalization layer parameters often have weight decay turned off.
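If you are using PyTorch's built-in schedulers, one way (among several) to attach a warmup schedule to the optimizer defined above is torch.optim.lr_scheduler.LambdaLR; the step counts below are placeholders, and a cosine decay could be substituted for the linear decay shown:

from torch.optim.lr_scheduler import LambdaLR

warmup_steps = 2_000      # placeholder; set from your training configuration
total_steps = 100_000     # placeholder

def lr_lambda(step):
    """Multiplier applied to peak_lr: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

The commented loop above shows where scheduler.step() fits relative to optimizer.step().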
Key Takeaways:
Start with hyperparameters proven effective for similar LLM training setups found in research papers.
Prioritize tuning the peak learning rate and its associated schedule (warmup steps, decay type).
Tune weight decay based on validation performance to control overfitting.
Keep AdamW β1 and β2 at their defaults (0.9, 0.999) or potentially try β2=0.98 if stability issues are observed. Avoid extensive tuning unless necessary.
Leave ϵ at its default (1e−8) unless specific numerical precision issues arise, particularly in mixed-precision training.
Utilize robust logging and monitoring (Chapter 24) to track loss curves, gradient norms, and other metrics to diagnose issues related to hyperparameter choices.
Finding the optimal set of hyperparameters is an iterative process. Careful selection, informed by established practices and empirical observation through monitoring, is essential for navigating the complexities of large-scale model optimization.