To ensure a training run can be seamlessly resumed after an interruption, it's not enough to just save the model parameters. A complete training state includes several components that capture the exact point where training left off. Failing to save and restore any of these can lead to suboptimal convergence, difficulty in reproducing results, or incorrect continuation of the training process. Let's examine the essential pieces of state you need to capture.
The model parameters are the most fundamental part of the checkpoint. Often referred to as weights and biases, they define the learned function of the neural network. In PyTorch, the standard way to access them is through the state_dict() method. This dictionary maps each layer or buffer name to its corresponding tensor. Saving the model's state_dict ensures that when you resume, the model starts with the exact learned representations it had achieved before the interruption.
# Assuming 'model' is your PyTorch nn.Module instance
model_state = model.state_dict()
# Example: Saving the model state
# torch.save(model_state, 'model_checkpoint.pt')
# Example: Loading the model state
# loaded_state = torch.load('model_checkpoint.pt')
# model.load_state_dict(loaded_state)
For large models, the state_dict itself can reach hundreds of gigabytes or even terabytes. Handling files of this size requires careful consideration of storage and I/O efficiency, especially in distributed settings, which we'll discuss later.
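To get a feel for the sizes involved, you can estimate a checkpoint's footprint directly from the state_dict. The snippet below is a minimal sketch that assumes model is an existing nn.Module instance; it simply sums the bytes of every tensor in the state dictionary.
import torch

# Rough estimate of the serialized size of a model's state_dict.
# Assumes 'model' is an existing nn.Module instance.
total_bytes = sum(
    t.numel() * t.element_size()
    for t in model.state_dict().values()
    if torch.is_tensor(t)
)
print(f"Approximate state_dict size: {total_bytes / 1e9:.2f} GB")
As a point of reference, a 70-billion-parameter model stored in float32 already occupies roughly 280 GB before any optimizer state is added.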
Modern optimizers, particularly adaptive ones like Adam or AdamW commonly used for training large language models, maintain internal states beyond just the hyperparameters (like learning rate or weight decay). For instance, Adam maintains estimates of the first moment (mean) and second moment (uncentered variance) of the gradients for each parameter.
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$
Here, $m_t$ and $v_t$ represent the moving averages for a parameter at timestep $t$, based on the gradient $g_t$ and decay factors $\beta_1$ and $\beta_2$. These moment estimates are specific to the point in the training trajectory. If you only restore the model weights but reinitialize the optimizer, these historical gradient statistics are lost. The optimizer effectively starts from scratch, which can significantly perturb the training dynamics, potentially slowing down convergence or leading the model to a different, possibly worse, local minimum. Therefore, saving the optimizer's state is essential for smooth resumption.
Similar to models, PyTorch optimizers provide a state_dict() method.
# Assuming 'optimizer' is your PyTorch optimizer instance
# (e.g., torch.optim.AdamW)
optimizer_state = optimizer.state_dict()
# Example: Saving the optimizer state
# torch.save(optimizer_state, 'optimizer_checkpoint.pt')
# Example: Loading the optimizer state
# loaded_state = torch.load('optimizer_checkpoint.pt')
# optimizer.load_state_dict(loaded_state)
The optimizer state includes not only the moment estimates (for optimizers like Adam) but also internal step counts and potentially other hyperparameters managed by the optimizer instance.
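You can see this structure by inspecting the optimizer's state_dict after at least one update. The sketch below assumes the AdamW optimizer from the previous snippet has already taken a step; the per-parameter keys shown in the comments are those used by Adam-style optimizers and may differ for other optimizers.
# Inspect what an Adam-style optimizer actually stores.
# Assumes 'optimizer' is a torch.optim.AdamW instance that has
# already performed at least one optimizer.step().
state = optimizer.state_dict()
print(state.keys())  # dict_keys(['state', 'param_groups'])

# Per-parameter state: moment estimates and step count
first_param_state = next(iter(state['state'].values()))
print(first_param_state.keys())  # e.g. dict_keys(['step', 'exp_avg', 'exp_avg_sq'])

# Hyperparameters such as the learning rate live in 'param_groups'
print(state['param_groups'][0]['lr'])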
Large language model training almost universally employs learning rate schedules. Common strategies involve a warmup phase where the learning rate gradually increases, followed by a decay phase (e.g., linear, cosine, or polynomial decay). These schedules are critical for stable training and achieving good performance.
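As a concrete illustration, the sketch below builds a linear-warmup-then-cosine-decay schedule on top of LambdaLR. It assumes the optimizer instance from the earlier snippet, and the step counts (warmup_steps, total_steps) are placeholder values, not recommendations.
import math
import torch

# Sketch: linear warmup followed by cosine decay, expressed as a LambdaLR.
# 'warmup_steps' and 'total_steps' are illustrative placeholder values.
warmup_steps = 2000
total_steps = 100_000

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # ramp multiplier from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # decay toward 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
Because the multiplier is a function of the step count, the scheduler's notion of "current step" is itself state that must be saved and restored.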
The scheduler's behavior depends on the current training progress, typically measured in steps or epochs. To resume the learning rate schedule correctly, you must save its internal state. This might include the number of steps taken so far, the last computed learning rate, or other internal counters used by the specific scheduler logic.
# Assuming 'scheduler' is your PyTorch LR scheduler instance
# (e.g., torch.optim.lr_scheduler.LambdaLR)
# Make sure to step the scheduler appropriately during training
# (e.g., scheduler.step())
scheduler_state = scheduler.state_dict()
# Example: Saving the scheduler state
# torch.save(scheduler_state, 'scheduler_checkpoint.pt')
# Example: Loading the scheduler state
# loaded_state = torch.load('scheduler_checkpoint.pt')
# scheduler.load_state_dict(loaded_state)
Restoring the scheduler state ensures that the learning rate continues its prescribed trajectory from the point of interruption, rather than restarting the warmup or decay phase inappropriately.
Beyond the core model and optimization components, you need to track the overall progress of the training run. This typically includes:
- The current epoch number (how many passes over the dataset have been completed).
- The global step count (how many optimizer updates have been performed).
- The position within the current epoch or data stream, so that already-processed batches are not repeated.
Saving these counters allows you to know exactly where in the training curriculum and dataset iteration to resume.
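The sketch below shows one way these counters can be used when resuming; it assumes a deterministic, seeded batch order and uses hypothetical names such as train_loader and num_epochs.
# Sketch: fast-forwarding to the saved position after a restart.
# Assumes 'train_loader' yields batches in the same (seeded) order as before;
# 'global_step', 'num_epochs', and 'train_loader' are illustrative names.
batches_per_epoch = len(train_loader)
resume_epoch = global_step // batches_per_epoch
batches_to_skip = global_step % batches_per_epoch

for epoch in range(resume_epoch, num_epochs):
    for batch_idx, batch in enumerate(train_loader):
        if epoch == resume_epoch and batch_idx < batches_to_skip:
            continue  # skip batches already consumed before the interruption
        # ... forward pass, backward pass, optimizer.step(), scheduler.step() ...
        global_step += 1
Skipping by iterating is simple but can be slow for very large datasets; stateful or resumable data loaders avoid this cost, at the price of additional loader state to checkpoint.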
For strict reproducibility, especially in research settings or when debugging, it is important to save the state of the random number generators in use: Python's random module, NumPy's numpy.random, and PyTorch's CPU and CUDA generators (accessible via torch.get_rng_state() and torch.cuda.get_rng_state_all()). This ensures that data shuffling, dropout patterns, and any other stochastic elements in the training pipeline remain consistent upon resumption.
import torch
import random
import numpy as np
# Example: Saving RNG states
rng_states = {
    'python_rng_state': random.getstate(),
    'numpy_rng_state': np.random.get_state(),
    'torch_rng_state': torch.get_rng_state(),
    'cuda_rng_state': torch.cuda.get_rng_state_all()  # Saves state for all GPUs
}
# torch.save(rng_states, 'rng_checkpoint.pt')
# Example: Loading RNG states
# loaded_rng_states = torch.load('rng_checkpoint.pt')
# random.setstate(loaded_rng_states['python_rng_state'])
# np.random.set_state(loaded_rng_states['numpy_rng_state'])
# torch.set_rng_state(loaded_rng_states['torch_rng_state'])
# torch.cuda.set_rng_state_all(loaded_rng_states['cuda_rng_state'])
In practice, all these state components are typically saved together in a single dictionary or structured file format. This makes managing and loading checkpoints simpler.
# Assume model, optimizer, scheduler, global_step, current_epoch exist
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'global_step': global_step,
    'epoch': current_epoch,
    # Add RNG states if needed for strict reproducibility
    'rng_states': {
        'python_rng_state': random.getstate(),
        'numpy_rng_state': np.random.get_state(),
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state_all()
    },
    # Optionally include other metadata: loss, accuracy, framework versions, etc.
    'loss': current_loss
}
# Save the consolidated checkpoint
checkpoint_path = f"checkpoint_step_{global_step}.pt"
# torch.save(checkpoint, checkpoint_path)
# --- Later, when resuming ---
# loaded_checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(loaded_checkpoint['model_state_dict'])
# optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(loaded_checkpoint['scheduler_state_dict'])
# global_step = loaded_checkpoint['global_step']
# current_epoch = loaded_checkpoint['epoch']
# current_loss = loaded_checkpoint.get('loss', None) # Handle optional keys
# Restore RNG states if saved
# rng_states = loaded_checkpoint.get('rng_states')
# if rng_states:
#     random.setstate(rng_states['python_rng_state'])
#     np.random.set_state(rng_states['numpy_rng_state'])
#     torch.set_rng_state(rng_states['torch_rng_state'])
#     torch.cuda.set_rng_state_all(rng_states['cuda_rng_state'])
# Now you can continue the training loop
# from the restored state
By diligently saving these components, you establish a robust foundation for fault tolerance. When an interruption occurs, you can confidently restore the complete training context and continue the process with minimal disruption and wasted computation. The next sections will explore how to manage this process effectively in distributed environments and discuss strategies for optimizing checkpoint frequency and storage.