Saving the state of a single training process is straightforward, as discussed in the previous section. However, large language model training almost invariably involves multiple compute nodes and devices working in parallel. This distributed nature introduces significant complexity to the checkpointing process. Simply having each worker save its state independently is insufficient; coordination is required to ensure the collective saved state represents a valid and consistent snapshot of the entire training job. Without this coordination, resuming training could lead to divergent behavior or incorrect results.
The primary challenge stems from the need for consistency. All participating processes (often referred to as ranks) must save their portion of the training state corresponding to the same point in the computation, typically at the end of a specific training step. If different ranks save at slightly different times, perhaps one rank finishing its gradient update while another is still calculating gradients, the resulting checkpoint would be inconsistent and likely unusable.
The most fundamental technique to ensure consistency is synchronization. Before initiating the save operation, all ranks must synchronize to guarantee they have reached the same logical point in the training loop. In frameworks like PyTorch Distributed Data Parallel (DDP), this is often achieved using collective communication operations like barriers.
import torch
import torch.distributed as dist
import os

# Assume setup_distributed() initializes the process group
# setup_distributed()

def save_checkpoint_distributed(
    model, optimizer, scheduler, epoch, step, checkpoint_dir
):
    """Saves a checkpoint, coordinated across ranks."""
    # Ensure all ranks are ready to save
    dist.barrier()

    # Designate one rank (usually rank 0) to handle non-sharded saving
    if dist.get_rank() == 0:
        # Create checkpoint directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Prepare the state dictionary
        # Note: For DDP, model.module should be saved to strip the DDP wrapper
        state = {
            'epoch': epoch,
            'step': step,
            'model_state_dict': model.module.state_dict(),  # Save the underlying model
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Add any other necessary states (e.g., RNG states, dataloader state)
        }

        # Define the checkpoint path
        checkpoint_filename = f"checkpoint_epoch_{epoch}_step_{step}.pt"
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)

        # Save the state dictionary
        torch.save(state, checkpoint_path)
        print(f"Rank 0: Saved checkpoint to {checkpoint_path}")

    # Ensure all ranks wait until rank 0 finishes saving before proceeding
    dist.barrier()


# Example usage within a training loop (simplified)
# model = ...          # Your DDP-wrapped model
# optimizer = ...
# scheduler = ...
# checkpoint_dir = "/path/to/checkpoints"
# current_epoch = 1
# current_step = 5000
# save_checkpoint_distributed(
#     model, optimizer, scheduler, current_epoch, current_step, checkpoint_dir
# )
In the example above, dist.barrier() acts as a synchronization point. The first barrier ensures all ranks pause before rank 0 starts saving. Rank 0 then saves the necessary state dictionaries. Critically, for a model wrapped with DDP, we save model.module.state_dict() to store the parameters of the original model, not the DDP wrapper itself. The second barrier ensures that no rank proceeds with the next training step until rank 0 has successfully completed the save operation. This prevents race conditions where some ranks might start modifying the state while it's being saved.
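The state dictionary above leaves a placeholder for additional state such as RNG states. As a minimal sketch (not part of the original example, and assuming one GPU per process), the rank-local RNG states could be captured and restored with helpers like these:

import random
import numpy as np
import torch

def get_rng_state():
    """Collect Python, NumPy, and Torch RNG states for the current process."""
    return {
        'python': random.getstate(),
        'numpy': np.random.get_state(),
        'torch_cpu': torch.get_rng_state(),
        # Assumes one visible GPU per process
        'torch_cuda': torch.cuda.get_rng_state(),
    }

def set_rng_state(rng_state):
    """Restore RNG states captured by get_rng_state()."""
    random.setstate(rng_state['python'])
    np.random.set_state(rng_state['numpy'])
    torch.set_rng_state(rng_state['torch_cpu'])
    torch.cuda.set_rng_state(rng_state['torch_cuda'])

Note that with rank 0 saving, only rank 0's RNG state ends up in the checkpoint; a fully reproducible resume requires each rank to persist and restore its own RNG state, for example in per-rank files.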
While this rank 0 saving approach works, it has limitations, especially at scale. If the model or optimizer state is sharded across ranks (as with ZeRO or model parallelism), it must first be gathered onto a single rank, which creates a communication bottleneck and requires significant memory on rank 0. Furthermore, the write itself is serialized through rank 0, leaving every other rank idle until it completes.
A more scalable approach, particularly relevant when using memory optimization techniques like ZeRO (Zero Redundancy Optimizer) or tensor/pipeline parallelism, is to save sharded checkpoints. In a sharded checkpoint, each rank saves only its portion of the overall training state.
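As a concrete, framework-neutral illustration, recent PyTorch releases (2.2 or later is assumed here) ship torch.distributed.checkpoint (DCP), which coordinates a collective save so the write is distributed across ranks rather than funneled through rank 0. The sketch below saves only the model state for brevity; the function and directory names are illustrative:

import os
import torch.distributed.checkpoint as dcp

# Minimal sketch of a sharded save with torch.distributed.checkpoint.
# Every rank must call this; DCP plans which rank writes which pieces and
# produces shard files plus a metadata file in the target directory.
def save_sharded_checkpoint(model, checkpoint_dir, step):
    state_dict = {'model': model.module.state_dict()}  # model.module for DDP
    dcp.save(
        state_dict,
        checkpoint_id=os.path.join(checkpoint_dir, f"step_{step}"),
    )

Optimizer state can be included as well, though converting it into a DCP-friendly form is usually done with the helpers in torch.distributed.checkpoint.state_dict (such as get_state_dict), which are likewise version-dependent.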
Libraries like DeepSpeed and Megatron-LM provide higher-level APIs that abstract away much of the complexity of managing sharded checkpoints. They handle the synchronization and ensure that each rank saves the correct state corresponding to its role in the parallelism configuration.
# Example using a DeepSpeed-like API (actual API may vary)
# Assume 'model_engine' is the DeepSpeed-wrapped model/optimizer engine.
# DeepSpeed often uses a tag/label for checkpoints.
checkpoint_tag = f"epoch_{epoch}_step_{step}"
checkpoint_dir = "/path/to/sharded/checkpoints"

# DeepSpeed's save_checkpoint handles sharding and synchronization internally.
# It saves model state, optimizer state, scheduler state, etc.
# Each rank writes its own shard(s) to the directory.
# client_state stores small extra values that are returned again at load time.
save_status = model_engine.save_checkpoint(
    checkpoint_dir,
    checkpoint_tag,
    client_state={'epoch': epoch, 'step': step},
)

if save_status:
    print(
        f"Rank {dist.get_rank()}: Successfully saved "
        f"sharded checkpoint {checkpoint_tag}"
    )
else:
    print(
        f"Rank {dist.get_rank()}: Failed to save "
        f"sharded checkpoint {checkpoint_tag}"
    )

# No explicit barrier needed here; synchronization is managed by the
# DeepSpeed function.
When using sharded checkpoints, the checkpoint_dir will contain multiple files, each representing a shard from a different rank or containing metadata. The loading process must also be aware of this sharded structure to correctly reconstruct the global state across all ranks.
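Continuing the torch.distributed.checkpoint sketch above, loading follows the same collective pattern: every rank allocates its local state_dict, and DCP fills it in place from the shard and metadata files (again a sketch, assuming PyTorch 2.2 or later):

import os
import torch.distributed.checkpoint as dcp

# Sketch only: every rank participates in the collective load.
def load_sharded_checkpoint(model, checkpoint_dir, step):
    # Allocate the target state dict on this rank (model.module for DDP)
    state_dict = {'model': model.module.state_dict()}
    # DCP reads the shard/metadata files and fills state_dict in place
    dcp.load(
        state_dict,
        checkpoint_id=os.path.join(checkpoint_dir, f"step_{step}"),
    )
    model.module.load_state_dict(state_dict['model'])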
Loading a checkpoint in a distributed setting also requires coordination. For a checkpoint saved by rank 0, every rank typically loads the same file, maps it to its own device, and applies the model state_dict on all ranks, but optimizer states might need manual handling depending on the setup.
# Example loading for the rank 0 saving approach
def load_checkpoint_distributed(model, optimizer, scheduler, checkpoint_path):
    """Loads a checkpoint saved by rank 0 onto all ranks."""
    # Ensure the checkpoint path exists
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint path {checkpoint_path} does not exist. "
              f"Starting from scratch.")
        return 0, 0  # Return starting epoch/step

    # Map tensors saved from rank 0's GPU directly onto this rank's device,
    # so every rank does not try to materialize the checkpoint on cuda:0.
    # Assumes one GPU per process with device index equal to the rank.
    rank_device = 'cuda:%d' % dist.get_rank()
    map_location = {'cuda:0': rank_device}
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Load model state (remember to use model.module for DDP)
    model.module.load_state_dict(checkpoint['model_state_dict'])

    # Load optimizer and scheduler states
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Load epoch and step
    epoch = checkpoint['epoch']
    step = checkpoint['step']

    print(f"Rank {dist.get_rank()}: Loaded checkpoint from "
          f"{checkpoint_path} (Epoch {epoch}, Step {step})")

    # Ensure all ranks have loaded before proceeding
    dist.barrier()
    return epoch, step
# checkpoint_to_load = "/path/to/checkpoints/checkpoint_epoch_1_step_5000.pt"
# start_epoch, start_step = load_checkpoint_distributed(model, optimizer, scheduler, checkpoint_to_load)
For sharded checkpoints, the framework's corresponding load function (for example, model_engine.load_checkpoint in DeepSpeed) handles reading the appropriate shards for each rank and reconstructing the distributed state. The process is significantly simpler from the user's perspective, as the library manages the complexity.
# Example using a DeepSpeed-like API
# checkpoint_dir = "/path/to/sharded/checkpoints"
# checkpoint_tag = f"epoch_{epoch}_step_{step}"  # The tag used during saving

# DeepSpeed's load_checkpoint handles reading shards and distributing state
load_path, client_state = model_engine.load_checkpoint(
    checkpoint_dir, checkpoint_tag
)

if load_path:
    print(f"Rank {dist.get_rank()}: Successfully loaded sharded "
          f"checkpoint {checkpoint_tag} from {load_path}")
    # client_state contains the values passed at save time (epoch, step, ...)
    start_epoch = client_state.get('epoch', 0)
    start_step = client_state.get('step', 0)
else:
    print(f"Rank {dist.get_rank()}: Could not find checkpoint "
          f"{checkpoint_tag}, starting from scratch.")
    start_epoch, start_step = 0, 0
Handling distributed checkpointing correctly is fundamental for reliable large-scale model training. While manual implementation requires careful synchronization and state management, leveraging features within distributed training frameworks like DeepSpeed or Megatron-LM often provides a more robust and scalable solution by automating sharding and synchronization. Remember to always test your checkpoint saving and loading procedures thoroughly to ensure they function correctly in your specific distributed setup.
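As one way to act on that advice, a simple rank-local sanity check (a hypothetical helper, not part of any framework) is to reload the checkpoint into a second, freshly constructed model and compare parameters:

def verify_checkpoint_roundtrip(model, reloaded_model):
    """Compare parameters of the live DDP model and a model reloaded from the
    checkpoint on this rank; returns True if every tensor matches exactly."""
    original = model.module.state_dict()
    restored = reloaded_model.module.state_dict()
    for name, tensor in original.items():
        if not torch.equal(tensor.cpu(), restored[name].cpu()):
            print(f"Rank {dist.get_rank()}: mismatch in parameter {name}")
            return False
    return True

Running a check like this after changing the parallelism configuration or upgrading the training framework catches silent mismatches before they cost a long training run.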