Hardware failure during the training of terabyte-scale models is a statistical certainty rather than a possibility. When scaling to hundreds of GPUs, the Mean Time Between Failures (MTBF) drops drastically, requiring training loops that are not just performant but resilient. Relying on manual restarts or simple epoch-based saving is insufficient for jobs that may run for weeks.
This section implements a production-grade training loop utilizing PyTorch's Distributed Checkpointing (DCP) API and TorchElastic. We focus on persisting the SHARDED_STATE_DICT, which allows each GPU to save its local partition of parameters and optimizer states directly to storage. This approach bypasses the memory bottlenecks associated with aggregating a full model on a single rank.
To achieve fault tolerance, the training script must function as an idempotent state machine. Upon startup, it checks for existing snapshots. If a snapshot exists, the system restores the model, optimizer, and learning rate scheduler to their exact states at the last recorded step. If no snapshot is found, training begins from scratch.
The following state diagram illustrates the recovery flow managed by TorchElastic when a worker fails.
The flow shows how TorchElastic detects the failure and automatically re-executes the entry point, which immediately runs the checkpoint check logic.
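To make that restart path concrete, here is a minimal sketch of such an entry point. It assumes the CheckpointManager and train_loop defined later in this section; build_training_objects is a hypothetical helper standing in for your model, optimizer, and data loader setup.
# Minimal sketch of the entry point that TorchElastic re-executes after a failure.
# Assumes CheckpointManager and train_loop from later in this section;
# build_training_objects() is a hypothetical placeholder for model setup.
import torch.distributed as dist

def main():
    # torchrun/TorchElastic sets RANK, WORLD_SIZE, and MASTER_ADDR on every restart attempt
    dist.init_process_group(backend="nccl")

    model, optimizer, scheduler, train_loader = build_training_objects()

    # train_loop checks for a snapshot before the first step, so re-running
    # main() after a crash resumes training instead of starting over
    train_loop(dist.get_rank(), model, optimizer, scheduler, train_loader)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()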
Before implementing the save/load logic, we must configure the FSDP model to yield sharded tensors. By default, calling .state_dict() on an FSDP module might attempt to gather the full weights, leading to OOM. We use the FSDP.state_dict_type context manager to enforce sharded persistence.
We define a CheckpointManager class to encapsulate the complexity of distributed I/O. This class handles the path management and interacts with torch.distributed.checkpoint.
import os
import shutil
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


class CheckpointManager:
    def __init__(self, checkpoint_folder):
        self.checkpoint_folder = checkpoint_folder

    def save(self, model, optimizer, scheduler, step):
        # Create a state payload containing all necessary components.
        # The model and optimizer are passed as references so DCP can
        # capture their sharded state on each rank.
        state_payload = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "metadata": {"step": step},
        }

        # Configure FSDP to yield local shards only.
        # This prevents gathering the full weights onto a single rank.
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            # dcp.save handles the parallel write from all ranks
            dcp.save(
                state_dict=state_payload,
                storage_writer=dcp.FileSystemWriter(self.checkpoint_folder),
            )

    def load(self, model, optimizer, scheduler):
        # Check whether a checkpoint exists (assumes shared storage visible to all ranks)
        if not os.path.exists(self.checkpoint_folder):
            return 0  # Start from step 0

        # The payload structure must match the one used during save
        state_payload = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "metadata": {"step": 0},  # Default value, overwritten on load
        }

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            dcp.load(
                state_dict=state_payload,
                storage_reader=dcp.FileSystemReader(self.checkpoint_folder),
            )

        return state_payload["metadata"]["step"]
A frequent issue in distributed systems is corruption caused by a crash occurring during the write operation. If a job is preempted while writing checkpoint_1000, the folder may contain partial data. Upon restart, loading this corrupted checkpoint will crash the training loop again, creating a failure loop.
To mitigate this, we implement an atomic save strategy. We write to a temporary directory first and verify success before renaming it to the permanent checkpoint path. Since directory renames are atomic on POSIX filesystems, the checkpoint is either fully present or non-existent.
Here is the revised training loop integrating the CheckpointManager and atomic logic:
def train_loop(rank, model, optimizer, scheduler, train_loader):
    # Initialize manager
    ckpt_dir = "checkpoints/latest"
    manager = CheckpointManager(ckpt_dir)

    # Attempt to resume.
    # This must happen AFTER model wrapping and optimizer creation.
    start_step = manager.load(model, optimizer, scheduler)
    if start_step > 0 and rank == 0:
        print(f"Resuming training from step {start_step}")

    # Fast-forward the data loader if necessary (omitted for brevity).
    # In practice, use a StatefulDataLoader to restore the iterator position.

    model.train()
    for step, batch in enumerate(train_loader, start=start_step):
        inputs, targets = batch[0].to(rank), batch[1].to(rank)

        optimizer.zero_grad()
        output = model(inputs)
        loss = torch.nn.functional.cross_entropy(output, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Save checkpoint every 500 steps
        if step > 0 and step % 500 == 0:
            # Use a temp path for atomicity
            tmp_path = f"checkpoints/tmp_{step}"

            # 1. Write to temporary location.
            # All ranks participate in the write.
            with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
                dcp.save(
                    state_dict={
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "metadata": {"step": step},
                    },
                    storage_writer=dcp.FileSystemWriter(tmp_path),
                )

            # 2. Atomic swap (coordinator only).
            # Wait for all ranks to finish writing before renaming.
            torch.distributed.barrier()
            if rank == 0:
                if os.path.exists(ckpt_dir):
                    shutil.rmtree(ckpt_dir)
                os.rename(tmp_path, ckpt_dir)
                print(f"Checkpoint saved atomically at step {step}")

            # Ensure all ranks see the new directory structure before proceeding
            torch.distributed.barrier()
When TorchElastic restarts a job, the cluster topology might change. For instance, if a node fails permanently, the job might restart with fewer nodes (if configured for elastic scaling) or wait for a replacement node.
FSDP sharding depends on the world_size. If you train on 8 GPUs, the model is sharded 8 ways. If you restart on 16 GPUs, the sharding pattern changes. Standard torch.load fails here because the tensor shapes do not match.
However, torch.distributed.checkpoint handles this re-sharding automatically. It saves the tensors in a layout-agnostic format. During loading, DCP redistributes the weights according to the current world_size and FSDP config. This capability allows you to train on 64 GPUs, checkpoint, and resume debugging on 8 GPUs without manual conversion scripts.
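As an illustration, a sharded DCP checkpoint can also be collapsed offline into a single torch.save file for inspection or debugging on one process. The helper below ships with recent PyTorch releases under torch.distributed.checkpoint.format_utils; the paths are illustrative.
# Offline conversion of a DCP checkpoint into a single consolidated file.
# Available in recent PyTorch releases; paths here are illustrative.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Collapse the sharded checkpoint; no process group or GPUs required
dcp_to_torch_save("checkpoints/latest", "checkpoints/latest_full.pt")

# The consolidated file can then be inspected like any torch.save artifact
full_state = torch.load("checkpoints/latest_full.pt", map_location="cpu")
print(full_state.keys())  # e.g. model, optimizer, scheduler, metadata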
Writing terabytes of sharded data can saturate storage bandwidth. The chart below analyzes the time cost of checkpointing relative to compute time as model size increases.
As model size grows, I/O latency (blue bars) becomes significant. While compute time (red line) scales linearly with efficient parallelization, storage throughput often hits a hard ceiling.
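A rough back-of-the-envelope estimate shows why. The figures below are assumptions for illustration, not measurements:
# Back-of-the-envelope checkpoint cost estimate (all figures are assumptions)
params = 10e9                      # 10B-parameter model
weight_bytes = 2                   # bf16 weights
optim_bytes = 8                    # Adam: fp32 exp_avg + exp_avg_sq per parameter
checkpoint_bytes = params * (weight_bytes + optim_bytes)    # ~100 GB

aggregate_write_gb_s = 5           # assumed sustained storage throughput, GB/s
blocking_seconds = checkpoint_bytes / (aggregate_write_gb_s * 1e9)

print(f"Checkpoint size: {checkpoint_bytes / 1e9:.0f} GB")     # 100 GB
print(f"Synchronous write time: {blocking_seconds:.0f} s")     # 20 s per checkpoint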
To minimize this blocking time, use Asynchronous Checkpointing. PyTorch provides async_save within the DCP module. This spawns a background thread to handle the I/O writes, allowing the GPU to immediately return to the training loop.
# Enabling async save to hide I/O latency.
# Note: requires careful memory management, as CPU memory usage increases temporarily.
# async_save returns a future that resolves once the write completes.
checkpoint_future = dcp.async_save(
    state_dict=state_payload,
    storage_writer=dcp.FileSystemWriter(ckpt_dir, thread_count=4),
)
When implementing async saving, monitor CPU RAM usage. The state dictionary remains pinned in host memory until the write completes. On nodes with limited CPU RAM, overlapping the copy-to-host and write-to-disk phases may trigger OOM errors if the previous checkpoint has not finished writing before the next one begins.
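A simple way to keep host memory bounded is to allow only one checkpoint in flight at a time: hold on to the future returned by async_save and wait on it before staging the next save. A minimal sketch, assuming checkpoint_future starts as None:
# Throttle asynchronous checkpoints: at most one write in flight at a time.
# checkpoint_future holds the handle returned by the previous async_save call.
if checkpoint_future is not None:
    checkpoint_future.result()  # block until the previous write has fully completed

checkpoint_future = dcp.async_save(
    state_dict=state_payload,
    storage_writer=dcp.FileSystemWriter(ckpt_dir, thread_count=4),
)
With this guard in place, checkpointing overlaps with training while host memory stays bounded to a single staged copy of the state.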