Successfully saving a checkpoint is only half the battle; the ability to correctly resume training from that checkpoint is equally important to realize the benefits of fault tolerance. A naive resumption that loads only the model weights can lead to suboptimal training dynamics or incorrect results. True resumption requires restoring the entire training state to precisely where it left off before the interruption.
This involves loading not just the model parameters, but also the state of the optimizer, the learning rate scheduler, and potentially the data loading progress and random number generator states. Failure to restore any of these components can invalidate the subsequent training process. For instance, restarting an adaptive optimizer like AdamW without its accumulated momentum and variance estimates effectively resets its learning trajectory, potentially undoing significant progress. Similarly, restarting a learning rate schedule with warmup and decay from the beginning can drastically alter the optimization path.
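As a point of reference for the resume logic that follows, here is a minimal sketch of the saving side, assuming a standard single-process PyTorch loop; the dictionary keys (model_state_dict, optimizer_state_dict, and so on) are chosen to match the ones the loading code later in this section expects.

import random
import numpy as np
import torch

# Minimal sketch of the saving side, assuming a standard PyTorch training loop.
# The key names mirror those used by the loading code below.
def save_checkpoint(path, model, optimizer, scheduler, iteration, best_val_loss):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'iteration': iteration,
        'best_val_loss': best_val_loss,
        # RNG states let the resumed run reproduce the same randomness
        # (single-process case; multi-GPU setups need per-rank states).
        'rng_states': {
            'torch_rng_state': torch.get_rng_state(),
            'numpy_rng_state': np.random.get_state(),
            'python_rng_state': random.getstate(),
        },
    }
    torch.save(checkpoint, path)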
Let's examine the steps involved in implementing a robust resume mechanism.
First, your training script needs logic to detect if a resume operation is requested, typically via a command-line argument or configuration setting specifying the checkpoint path. It's often practical to automatically look for the latest valid checkpoint in a designated directory.
import torch
import os
import glob

def find_latest_checkpoint(checkpoint_dir):
    """Finds the latest checkpoint file based on iteration number."""
    list_of_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_*.pt'))
    if not list_of_files:
        return None
    # Assuming filename format 'checkpoint_iter_XXXXX.pt'
    latest_file = max(
        list_of_files,
        key=lambda f: int(f.split('_')[-1].split('.')[0])
    )
    return latest_file

# --- In your main training script ---
# config.resume_from_checkpoint = True or False
# config.checkpoint_dir = '/path/to/checkpoints'
# config.resume_checkpoint_path = None  # Optionally specify exact path

resume_path = None
if config.resume_from_checkpoint:
    if config.resume_checkpoint_path:
        resume_path = config.resume_checkpoint_path
    else:
        resume_path = find_latest_checkpoint(config.checkpoint_dir)

if resume_path and os.path.isfile(resume_path):
    print(f"Resuming training from checkpoint: {resume_path}")
    checkpoint = torch.load(resume_path, map_location='cpu')  # Load to CPU first
else:
    print("Starting training from scratch.")
    checkpoint = None
Loading the checkpoint onto the CPU first (map_location='cpu') before moving components to the target device can prevent GPU memory spikes, especially in multi-GPU setups.
Once the checkpoint dictionary is loaded, you need to restore the state of the core training components.
# Assume model, optimizer, and scheduler are already initialized
# (as they would be for starting training from scratch)
model = YourTransformerModel(config)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=config.learning_rate
    # ... other hyperparameters (weight decay, betas, ...) elided ...
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: 1.0  # Example scheduler; real schedule elided
)

start_iter = 0
best_val_loss = float('inf')

if checkpoint is not None:
    # Restore model state. strict=True is preferred for exact resumption;
    # use strict=False only to tolerate deliberate architecture changes.
    model.load_state_dict(
        checkpoint['model_state_dict'], strict=True
    )
    print("Model state loaded.")

    # Restore optimizer state: crucial for adaptive optimizers like AdamW,
    # which rely on accumulated momentum and variance buffers.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer state loaded.")
    else:
        print(
            "Warning: Optimizer state not found in checkpoint. "
            "Starting optimizer from scratch."
        )

    # Restore LR scheduler state
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        print("Scheduler state loaded.")
    else:
        print(
            "Warning: Scheduler state not found in checkpoint. "
            "Starting scheduler from scratch."
        )

    # Restore training progress
    if 'iteration' in checkpoint:
        start_iter = checkpoint['iteration'] + 1  # Start from the next iteration
        print(f"Resuming from iteration: {start_iter}")
    if 'best_val_loss' in checkpoint:
        best_val_loss = checkpoint['best_val_loss']
        print(f"Loaded best validation loss: {best_val_loss:.4f}")

    # Restore RNG states for reproducibility (optional but recommended)
    if 'rng_states' in checkpoint:
        torch.set_rng_state(
            checkpoint['rng_states']['torch_rng_state']
        )
        # Potentially restore numpy and python random states as well
        # import numpy as np
        # np.random.set_state(checkpoint['rng_states']['numpy_rng_state'])
        # import random
        # random.setstate(checkpoint['rng_states']['python_rng_state'])
        print("RNG states loaded.")

# Move model to the target device(s) AFTER loading the state dict
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Important: move optimizer state tensors to the correct device AFTER loading.
# Some frameworks handle this automatically, but raw PyTorch needs care.
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

print(f"Training will start/resume from iteration {start_iter}")
A critical detail that is often overlooked is ensuring the optimizer's state tensors are moved to the correct device after loading the state dict. While model parameters are moved via model.to(device), optimizer states (such as the momentum and variance buffers in AdamW) reside within the optimizer object and might need explicit device placement.
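As a quick sanity check after that move, you can verify that every tensor in the optimizer state now lives on the target device; this is a small optional sketch, not part of the required resume logic.

# Optional sanity check: confirm every optimizer state tensor is on the
# target device type after the explicit move above.
misplaced = [
    (name, tensor.device)
    for state in optimizer.state.values()
    for name, tensor in state.items()
    if isinstance(tensor, torch.Tensor) and tensor.device.type != device.type
]
if misplaced:
    print(f"Warning: optimizer state tensors on unexpected devices: {misplaced}")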
Restoring the data loading position is arguably the most intricate part of resuming. Simply restarting the data loader from the beginning of the dataset means re-processing samples that were already seen before the interruption within that epoch, which skews the training data distribution for that epoch and delays progress.
Ideally, you want the data loader to pick up exactly where it left off. Common strategies include fast-forwarding the data iterator past the batches already consumed in the interrupted epoch, or using a sampler or data loader that can save and restore its own position as part of the checkpoint. The fast-forward (skipping) approach is sketched below:
# --- Inside your training loop, BEFORE starting the loop ---
# Assume 'train_dataloader' is initialized (potentially seeded per epoch).
# We need to know which iteration we are resuming INTO the current epoch.

# Calculate the epoch number and the iteration within that epoch
effective_batch_size = config.batch_size * config.num_gpus  # Adjust for DP/DDP
# If using gradient accumulation, effective_batch_size doesn't change, but
# steps per epoch might; here we assume 'iteration' counts optimizer steps.
iterations_per_epoch = len(train_dataloader)
# Or calculate as: dataset size // effective_batch_size
start_epoch = start_iter // iterations_per_epoch
resume_iter_within_epoch = start_iter % iterations_per_epoch

print(
    f"Resuming into epoch {start_epoch}, starting from iteration "
    f"{resume_iter_within_epoch} within the epoch."
)

# --- Inside the epoch loop ---
for epoch in range(start_epoch, config.num_epochs):
    # Re-seed the dataloader sampler for reproducibility if needed
    # train_dataloader.sampler.set_epoch(epoch)  # If using DistributedSampler
    data_iter = iter(train_dataloader)

    # If resuming within this epoch, skip already processed batches
    if epoch == start_epoch and resume_iter_within_epoch > 0:
        print(
            f"Skipping {resume_iter_within_epoch} batches in epoch {epoch} "
            f"to resume state..."
        )
        for _ in range(resume_iter_within_epoch):
            try:
                next(data_iter)
            except StopIteration:
                # Should not happen if the checkpoint logic is correct
                print("Error: Tried to skip past the end of the dataloader.")
                break
        print("Skipping complete.")

    # Now run the actual training iterations for this epoch
    for step_in_epoch in range(
        resume_iter_within_epoch, iterations_per_epoch
    ):
        current_global_iter = epoch * iterations_per_epoch + step_in_epoch
        # Fetch the batch (handle StopIteration in case skipping went wrong)
        try:
            batch = next(data_iter)
        except StopIteration:
            print(
                f"Warning: DataLoader exhausted unexpectedly at step "
                f"{step_in_epoch} in epoch {epoch}."
            )
            break
        # ... rest of your training step: move batch to device, forward,
        # backward, optimizer step ...

    # Reset the resume marker so subsequent epochs start from batch 0
    resume_iter_within_epoch = 0
This skipping mechanism ensures that the model sees each data sample roughly the intended number of times across the entire training run, preserving the integrity of the training process. Libraries like torch.utils.data.DataLoader combined with samplers like DistributedSampler often require careful handling of epoch seeding (sampler.set_epoch(epoch)) to ensure proper data shuffling and distribution in distributed settings, especially when resuming.
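Here is a minimal sketch of how per-epoch seeding and batch skipping might fit together with DistributedSampler; names such as train_dataset, config, start_epoch, and resume_iter_within_epoch are assumed to come from the surrounding training script.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Sketch: per-epoch sampler seeding combined with batch skipping on resume.
sampler = DistributedSampler(train_dataset, shuffle=True, seed=config.seed)
train_dataloader = DataLoader(
    train_dataset, batch_size=config.batch_size, sampler=sampler
)

for epoch in range(start_epoch, config.num_epochs):
    # set_epoch() must be called every epoch, including the resumed one, so
    # each rank re-derives the same shuffle order it used before the
    # interruption; only then does skipping the first N batches line up.
    sampler.set_epoch(epoch)
    data_iter = iter(train_dataloader)
    if epoch == start_epoch:
        for _ in range(resume_iter_within_epoch):
            next(data_iter)  # each rank skips batches from its own shard
    # ... training steps for the remaining batches ...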
In a distributed training environment (DDP, FSDP, ZeRO), resuming requires careful coordination (a minimal synchronization sketch follows this list):
- Synchronization: use torch.distributed.barrier() to ensure all processes have located the checkpoint before loading.
- Sharded states: a framework's load_checkpoint function typically handles loading the appropriate shard for each rank automatically. Directly using torch.load and optimizer.load_state_dict might not work correctly for these sharded states; rely on the framework's utilities.
- Data sampling: when using DistributedSampler, ensure set_epoch() is called correctly upon resuming and that the skipping logic accounts for the per-rank data shard. Each rank skips batches within its own portion of the data.
# Example using DeepSpeed's checkpoint loading (simplified)
# Assumes 'model_engine' is the DeepSpeed engine wrapping model, optimizer, etc.
# Somewhere in your setup code:
load_path, client_state = model_engine.load_checkpoint(
    config.checkpoint_dir, tag=config.checkpoint_tag
)

if load_path is not None:
    print(f"Resumed training from checkpoint {load_path}")
    # DeepSpeed's load_checkpoint returns client_state, which can contain the
    # iteration count, RNG states, and other custom values that you saved.
    start_iter = client_state.get('iteration', 0) + 1
    # ... restore other custom states from client_state ...
else:
    print("Starting training from scratch.")
    start_iter = 0

# DeepSpeed restores model, optimizer, and scheduler states internally.
# You primarily need to restore the custom states you added to 'client_state'
# when saving, and handle the data loader skipping based on the restored
# 'start_iter'.
After resuming, it's good practice to verify that the state was restored correctly. One simple check is to log the loss and learning rate immediately after the first resumed training step and compare them to the values logged just before the interruption (if available). They should be very close, accounting for minor floating-point variations and the effect of the next data batch. Significant deviations might indicate an issue in the resume logic. Thoroughly testing the save/resume functionality on smaller runs before launching large-scale jobs is highly recommended.
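A minimal sketch of such a check follows, assuming you additionally stored the last logged loss and learning rate in the checkpoint under the illustrative keys 'last_loss' and 'last_lr' (these keys are not part of the checkpoint format used above), and that loss is the loss tensor from the first resumed training step.

# Hypothetical post-resume sanity check. 'last_loss' and 'last_lr' are
# illustrative keys assumed to have been added when saving; 'loss' comes
# from the first training step executed after resuming.
if checkpoint is not None and 'last_loss' in checkpoint:
    current_lr = optimizer.param_groups[0]['lr']
    print(
        f"Before interruption: loss={checkpoint['last_loss']:.4f}, "
        f"lr={checkpoint['last_lr']:.3e}"
    )
    print(f"After resuming:      loss={loss.item():.4f}, lr={current_lr:.3e}")
    # The scheduler has advanced by one step and the data batch differs, so
    # expect values that are close, not bit-identical.
    relative_lr_gap = abs(current_lr - checkpoint['last_lr']) / max(checkpoint['last_lr'], 1e-12)
    if relative_lr_gap > 0.05:
        print("Warning: learning rate deviates noticeably from the pre-interruption value.")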
By carefully restoring the model, optimizer, scheduler, data loader position, and other metadata, you ensure that training continues seamlessly after an interruption, saving valuable time and computational resources.