Training a model for days or weeks across a large cluster of machines makes hardware or software failures an operational inevitability. A single node failure, a network partition, or a spot instance preemption can obliterate thousands of GPU-hours of computation. Therefore, building fault tolerance directly into the training workflow is not an optional enhancement but a core architectural requirement for large-scale AI. The principal technique for achieving this resilience is systematic checkpointing.
A checkpoint is a complete, persistent snapshot of a training job's state, allowing it to be resumed from the exact point of failure. Saving only the model weights is insufficient for a seamless recovery. A comprehensive checkpoint must include the model weights, the full optimizer state (such as momentum and variance buffers), the learning rate scheduler state, the current training step or epoch, and the random number generator states needed to reproduce data ordering and augmentation.
An effective checkpointing strategy balances the overhead of saving state against the potential loss of computation.
The central trade-off is frequency. Checkpointing too often introduces significant I/O overhead, as writing gigabytes of data to a remote store can stall training. Checkpointing too infrequently increases the amount of work lost in a failure. A balanced approach often involves triggering checkpoints based on a fixed interval of time (e.g., every 60 minutes) or a set number of training steps. The optimal frequency depends on the stability of your infrastructure and the cost of your compute resources.
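As a minimal sketch of such a combined trigger (the interval values and the should_checkpoint helper below are illustrative, not part of any particular framework), a checkpoint can fire on whichever limit is reached first:

# Sketch: combined step- and time-based checkpoint trigger.
import time

SAVE_INTERVAL_STEPS = 1000          # checkpoint at least every 1,000 steps
SAVE_INTERVAL_SECONDS = 60 * 60     # ... or at least every 60 minutes

def should_checkpoint(step, last_ckpt_step, last_ckpt_time):
    """Return True if enough steps or enough wall-clock time has elapsed."""
    return (step - last_ckpt_step >= SAVE_INTERVAL_STEPS
            or time.monotonic() - last_ckpt_time >= SAVE_INTERVAL_SECONDS)

# In the training loop:
# if should_checkpoint(step, last_ckpt_step, last_ckpt_time):
#     save_checkpoint(model, optimizer, step)
#     last_ckpt_step, last_ckpt_time = step, time.monotonic()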
The choice of storage backend is significant for performance and reliability. While Chapter 1 detailed various storage systems, their use in checkpointing involves specific trade-offs: local NVMe or SSD offers the lowest write latency but does not survive the loss of the node, shared file systems are convenient for multi-node jobs but can become a bottleneck when many workers write at once, and object storage is durable and inexpensive at scale but adds network latency to every save and restore.
Modern distributed training frameworks provide built-in support for managing the complexities of saving a sharded state.
Horovod is framework-agnostic and delegates checkpointing logic to the user. The standard implementation pattern is to designate a single worker, typically rank == 0, to handle the save operation. This prevents a "thundering herd" problem where all workers attempt to write to the same location, causing write contention and potential file corruption.
# A common checkpointing pattern in a Horovod training script
import os

import torch
import horovod.torch as hvd

# Initialize Horovod
hvd.init()

# ... model, optimizer, and other initializations ...

def save_checkpoint(model, optimizer, step):
    # Only the primary worker (rank 0) saves the checkpoint.
    if hvd.rank() == 0:
        state = {
            'step': step,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            # ... include other state like the scheduler ...
        }
        # Best practice: save to a temporary file and then perform an atomic rename.
        # This prevents resuming from a partially written, corrupt checkpoint.
        tmp_path = "/path/to/durable/storage/checkpoint_step_{}.tmp".format(step)
        final_path = "/path/to/durable/storage/checkpoint_step_{}.pt".format(step)
        torch.save(state, tmp_path)
        os.rename(tmp_path, final_path)
        print(f"Checkpoint saved to {final_path}")

# In the training loop:
# if step % config.save_interval == 0:
#     save_checkpoint(model, optimizer, step)
Frameworks like DeepSpeed and PyTorch's Fully Sharded Data Parallel (FSDP) are aware of how model and optimizer states are sharded across devices. They provide high-level APIs that abstract away the complexity of gathering and saving this distributed state.
DeepSpeed exposes a single call, model_engine.save_checkpoint(), which automatically handles the serialization of sharded model parameters, optimizer states, and other training components into a designated directory. PyTorch FSDP integrates with the torch.distributed.checkpoint module, whose APIs are designed to save and load sharded tensors directly to storage without first gathering the full model into a single GPU's memory, a critical capability for training models that are too large to fit on one device.
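For example, with DeepSpeed the entire sharded state is saved and restored through two engine calls. The sketch below is illustrative rather than a drop-in script: the placeholder model, config path, checkpoint directory, and the "step" key in client_state are assumptions, and exact behavior depends on your DeepSpeed configuration.

# Sketch: saving and resuming a sharded checkpoint with DeepSpeed.
# The config file, directory, and client_state contents are illustrative.
import torch
import deepspeed

model = torch.nn.Linear(4096, 4096)  # placeholder model for illustration
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # assumed DeepSpeed config with ZeRO settings
)

ckpt_dir = "/path/to/durable/storage/checkpoints"
step = 1000  # current training step

# Every rank calls save_checkpoint(); DeepSpeed writes each rank's shards plus metadata.
model_engine.save_checkpoint(ckpt_dir, tag=f"step_{step}",
                             client_state={"step": step})

# On restart, every rank loads its own shards back.
load_path, client_state = model_engine.load_checkpoint(ckpt_dir, tag=f"step_{step}")
resume_step = client_state["step"] if client_state else 0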
Saving a checkpoint is only half the solution. A production-grade system must automate the recovery process. This responsibility falls to the workload orchestrator, such as a Kubernetes Job controller or a Slurm scheduler. The automated recovery workflow involves several steps: the orchestrator detects the failure through health checks or a non-zero exit code, reschedules the job on healthy hardware, and the restarted process locates the most recent valid checkpoint in durable storage, restores its state from it, and continues training.
An automated recovery loop. When a pod fails, the orchestrator finds the last successful checkpoint in object storage and launches a new pod, instructing it to resume from that state.
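On the application side, resuming is straightforward if every process scans durable storage for the newest complete checkpoint at startup. The sketch below assumes the checkpoint_step_<N>.pt naming scheme used in the earlier Horovod example; the directory path is illustrative.

# Sketch: resume-from-latest logic run at process startup.
import glob
import os
import re

import torch

CKPT_DIR = "/path/to/durable/storage"

def find_latest_checkpoint(ckpt_dir):
    """Return (path, step) of the newest complete checkpoint, or (None, 0)."""
    latest_path, latest_step = None, 0
    for path in glob.glob(os.path.join(ckpt_dir, "checkpoint_step_*.pt")):
        match = re.search(r"checkpoint_step_(\d+)\.pt$", path)
        if match and int(match.group(1)) > latest_step:
            latest_path, latest_step = path, int(match.group(1))
    return latest_path, latest_step

def restore(model, optimizer):
    """Load the latest checkpoint if one exists; return the step to resume from."""
    path, step = find_latest_checkpoint(CKPT_DIR)
    if path is None:
        return 0  # fresh start: no checkpoint found
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["step"]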
For highly optimized environments, more advanced patterns are common.
Cloud providers typically provide a short warning (e.g., 30-120 seconds) before terminating a spot instance. A well-designed training application can trap this signal. A background process can poll the instance metadata service for a termination notice. When a notice is received, it triggers an emergency checkpoint, ensuring minimal work is lost. This makes volatile, low-cost spot instances a highly viable option for long-running training jobs.
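A minimal sketch of such a watcher follows. It polls the AWS spot instance-action metadata endpoint from a daemon thread; other providers (and Kubernetes, via SIGTERM) expose different signals, and newer AWS instances using IMDSv2 also require a session token, so treat the details as assumptions.

# Sketch: background watcher for a spot termination notice (AWS-style endpoint).
import threading
import time
import urllib.request

TERMINATION_URL = "http://169.254.169.254/latest/meta-data/spot/instance-action"
preemption_event = threading.Event()

def watch_for_preemption(poll_seconds=5):
    """Set preemption_event as soon as a termination notice appears."""
    while not preemption_event.is_set():
        try:
            # The endpoint returns 404 until a termination notice is issued.
            with urllib.request.urlopen(TERMINATION_URL, timeout=1) as resp:
                if resp.status == 200:
                    preemption_event.set()
                    return
        except Exception:
            pass  # 404 or network error: no notice yet
        time.sleep(poll_seconds)

threading.Thread(target=watch_for_preemption, daemon=True).start()

# In the training loop:
# if preemption_event.is_set():
#     save_checkpoint(model, optimizer, step)  # emergency checkpoint
#     break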
To minimize the training stall caused by writing large checkpoints to remote storage, you can use an asynchronous pattern. The training process performs a fast save to a local SSD, allowing computation to resume almost immediately. A separate background thread or process is then responsible for uploading the checkpoint from the local disk to durable object storage. This decouples the training loop from the high-latency network I/O, improving computational efficiency.
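A minimal sketch of this two-phase pattern, assuming a local NVMe mount and a hypothetical upload_to_object_store helper standing in for a real S3 or GCS client:

# Sketch: two-phase asynchronous checkpointing. The paths and the
# upload_to_object_store helper are illustrative placeholders.
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

import torch

LOCAL_DIR = "/local_ssd/checkpoints"           # fast local NVMe
REMOTE_DIR = "/mnt/object_store/checkpoints"   # durable remote target (illustrative)
_uploader = ThreadPoolExecutor(max_workers=1)

def upload_to_object_store(local_path, remote_path):
    # Placeholder: in practice this would be an S3/GCS multipart upload.
    shutil.copy(local_path, remote_path)

def async_save(state, step):
    """Block only for the fast local write; the upload runs in the background."""
    os.makedirs(LOCAL_DIR, exist_ok=True)
    local_path = os.path.join(LOCAL_DIR, f"checkpoint_step_{step}.pt")
    torch.save(state, local_path)  # fast, blocking local write
    remote_path = os.path.join(REMOTE_DIR, f"checkpoint_step_{step}.pt")
    _uploader.submit(upload_to_object_store, local_path, remote_path)  # non-blocking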