Hardware failure during the training of terabyte-scale models is a statistical certainty rather than a possibility. When scaling to hundreds of GPUs, the Mean Time Between Failures (MTBF) drops drastically, requiring training loops that are not just performant but resilient. Relying on manual restarts or simple epoch-based saving is insufficient for jobs that may run for weeks.

This section implements a production-grade training loop built on PyTorch's Distributed Checkpointing (DCP) API and TorchElastic. We focus on persisting the `SHARDED_STATE_DICT`, which allows each GPU to save its local partition of parameters and optimizer state directly to storage. This approach bypasses the memory bottleneck of aggregating a full model on a single rank.

### The Resumable Training Lifecycle

To achieve fault tolerance, the training script must function as an idempotent state machine. Upon startup, it checks for existing snapshots. If a snapshot exists, the system restores the model, optimizer, and learning rate scheduler to their exact states at the last recorded step. If no snapshot is found, training begins from scratch.

The following state diagram illustrates the recovery flow managed by TorchElastic when a worker fails.

```dot
digraph TrainingLifecycle {
    rankdir=TB;
    node [shape=box, style=filled, fontname="Helvetica", fontsize=12, color="#dee2e6"];
    edge [fontname="Helvetica", fontsize=10, color="#868e96"];

    Start     [label="Process Start (torchrun)", fillcolor="#e7f5ff", color="#74c0fc"];
    CheckSnap [label="Check for Checkpoint", fillcolor="#fff9db", color="#ffe066"];
    LoadSnap  [label="Load Sharded State\n(DCP)", fillcolor="#d3f9d8", color="#69db7c"];
    InitModel [label="Initialize Random Weights", fillcolor="#fff9db", color="#ffe066"];
    TrainLoop [label="Training Step N", fillcolor="#e7f5ff", color="#74c0fc"];
    SaveSnap  [label="Save Sharded State\n(Atomic Write)", fillcolor="#d3f9d8", color="#69db7c"];
    Failure   [label="Hardware/Process Failure", shape=octagon, fillcolor="#ffc9c9", color="#ff8787"];

    Start -> CheckSnap;
    CheckSnap -> LoadSnap [label="Found"];
    CheckSnap -> InitModel [label="Not Found"];
    LoadSnap -> TrainLoop;
    InitModel -> TrainLoop;
    TrainLoop -> TrainLoop [label="Step++"];
    TrainLoop -> SaveSnap [label="Interval Reached"];
    SaveSnap -> TrainLoop;
    TrainLoop -> Failure [style=dashed];
    SaveSnap -> Failure [style=dashed];
    Failure -> Start [label="Elastic Restart", style=dotted];
}
```

The flow shows how TorchElastic detects the failure (octagon) and automatically re-executes the entry point, which immediately triggers the checkpoint check.
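Because TorchElastic recovers by re-executing the same entry point (the "Elastic Restart" edge in the diagram), the startup path has to be safe to run repeatedly. The sketch below shows one way that entry point might look when launched with `torchrun`; `build_training_objects` and `train_loop` are placeholders for the application code developed in the rest of this section.

```python
import os

import torch
import torch.distributed as dist


def main():
    # torchrun injects RANK, LOCAL_RANK, and WORLD_SIZE on every (re)start,
    # so the same code path serves fresh launches and elastic restarts alike.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # Application-specific setup goes here (placeholders, not defined in this
    # sketch): build the model, wrap it in FSDP, create the optimizer and
    # scheduler, then enter the resumable loop, which decides whether to load
    # a checkpoint or start from scratch.
    #
    # model, optimizer, scheduler, loader = build_training_objects(local_rank)
    # train_loop(local_rank, model, optimizer, scheduler, loader)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

The restart budget itself is set at launch time via `torchrun`'s `--max-restarts` flag; once it is exhausted, the job is declared failed rather than re-spawned.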
### Configuring Sharded State Dictionaries

Before implementing the save/load logic, we must configure the FSDP model to yield sharded tensors. By default, calling `.state_dict()` on an FSDP module may attempt to gather the full weights, leading to out-of-memory errors. We use the `FSDP.state_dict_type` context manager to enforce sharded persistence.

We define a `CheckpointManager` class to encapsulate the complexity of distributed I/O. This class handles path management and interacts with `torch.distributed.checkpoint`.

```python
import os
import shutil

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


class CheckpointManager:
    def __init__(self, checkpoint_folder):
        self.checkpoint_folder = checkpoint_folder

    def save(self, model, optimizer, scheduler, step):
        # Create a state payload containing all necessary components.
        # The model and optimizer are passed as references so DCP can
        # capture their sharded state at write time.
        state_payload = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "metadata": {"step": step},
        }

        # Configure FSDP to yield local shards only.
        # This prevents gathering the full weights onto a single rank.
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            # dcp.save handles the parallel write from all ranks.
            dcp.save(
                state_dict=state_payload,
                storage_writer=dcp.FileSystemWriter(self.checkpoint_folder),
            )

    def load(self, model, optimizer, scheduler):
        # The payload structure must match the one used when saving.
        state_payload = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "metadata": {"step": 0},  # Default value, overwritten on load
        }

        # If no checkpoint exists, start from step 0.
        if not os.path.exists(self.checkpoint_folder):
            return 0

        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            dcp.load(
                state_dict=state_payload,
                storage_reader=dcp.FileSystemReader(self.checkpoint_folder),
            )

        return state_payload["metadata"]["step"]
```

### Atomic Checkpointing

A frequent issue in distributed systems is corruption caused by a crash during the write operation itself. If a job is preempted while writing `checkpoint_1000`, the folder may contain partial data. Upon restart, loading this corrupted checkpoint crashes the training loop again, creating a failure loop.

To mitigate this, we implement an atomic save strategy. We write to a temporary directory first and rename it to the permanent checkpoint path only after every rank has finished. Since directory renames are atomic on POSIX filesystems, the checkpoint directory is either fully present or absent. (Removing the previous checkpoint just before the rename still leaves a brief window with no valid snapshot on disk; retaining the last few checkpoints and pruning older ones closes that window at the cost of extra storage.)

Here is the revised training loop integrating the `CheckpointManager` and the atomic rename logic:

```python
def train_loop(rank, model, optimizer, scheduler, train_loader):
    # Initialize the checkpoint manager.
    ckpt_dir = "checkpoints/latest"
    manager = CheckpointManager(ckpt_dir)

    # Attempt to resume. This must happen AFTER the model is wrapped in FSDP
    # and the optimizer has been created.
    start_step = manager.load(model, optimizer, scheduler)
    if start_step > 0 and rank == 0:
        print(f"Resuming training from step {start_step}")

    # Fast-forward the data loader if necessary (omitted for brevity).
    # In practice, use a StatefulDataLoader to restore the iterator position.

    model.train()
    for step, batch in enumerate(train_loader, start=start_step):
        inputs, targets = batch[0].to(rank), batch[1].to(rank)

        optimizer.zero_grad()
        output = model(inputs)
        loss = torch.nn.functional.cross_entropy(output, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Save a checkpoint every 500 steps.
        if step > 0 and step % 500 == 0:
            # Use a temporary path for atomicity.
            tmp_path = f"checkpoints/tmp_{step}"

            # 1. Write to the temporary location. All ranks participate.
            with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
                dcp.save(
                    state_dict={
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "metadata": {"step": step},
                    },
                    storage_writer=dcp.FileSystemWriter(tmp_path),
                )

            # 2. Atomic swap (coordinator only).
            # Wait for all ranks to finish writing before renaming.
            torch.distributed.barrier()
            if rank == 0:
                if os.path.exists(ckpt_dir):
                    shutil.rmtree(ckpt_dir)
                os.rename(tmp_path, ckpt_dir)
                print(f"Checkpoint saved atomically at step {step}")

            # Ensure all ranks see the new directory before proceeding.
            torch.distributed.barrier()
```
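The loop above deliberately skips fast-forwarding the data loader. One way to close that gap, sketched here under the assumption that the `torchdata` package's `StatefulDataLoader` is available (it mirrors `torch.utils.data.DataLoader` and adds `state_dict()`/`load_state_dict()`), is to snapshot the iterator position and carry it alongside the step counter in the checkpoint's `metadata` entry:

```python
import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader  # assumed: torchdata >= 0.8

# Toy dataset standing in for the real training data.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 10, (1024,)))
loader = StatefulDataLoader(dataset, batch_size=32, shuffle=True)

# Consume a few batches, then capture the iterator position. In the training
# loop above, this snapshot would be stored next to the "step" value.
it = iter(loader)
for _ in range(5):
    next(it)
loader_state = loader.state_dict()

# After an elastic restart, a freshly constructed loader resumes mid-epoch
# by restoring that state before iteration begins.
resumed_loader = StatefulDataLoader(dataset, batch_size=32, shuffle=True)
resumed_loader.load_state_dict(loader_state)
```

Restoring the loader before iterating prevents samples from being repeated or skipped within the interrupted epoch.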
### Handling Topology Changes

When TorchElastic restarts a job, the cluster topology may change. For instance, if a node fails permanently, the job might restart with fewer nodes (if configured for elastic scaling) or wait for a replacement node.

FSDP sharding depends on the world size. If you train on 8 GPUs, the model is sharded 8 ways; if you restart on 16 GPUs, the sharding pattern changes. A standard `torch.load` fails here because the tensor shapes no longer match.

However, `torch.distributed.checkpoint` handles this re-sharding automatically. It saves tensors in a layout-agnostic format and, during loading, redistributes the weights according to the current world size and FSDP configuration. This capability allows you to train on 64 GPUs, checkpoint, and resume debugging on 8 GPUs without manual conversion scripts.

### Performance Implications of I/O

Writing terabytes of sharded data can saturate storage bandwidth. The chart below analyzes the time cost of checkpointing relative to compute time as model size increases.

```json
{
  "data": [
    {
      "x": ["7B", "13B", "70B", "175B"],
      "y": [12, 24, 115, 340],
      "name": "Checkpoint Time (s)",
      "type": "bar",
      "marker": {"color": "#74c0fc"}
    },
    {
      "x": ["7B", "13B", "70B", "175B"],
      "y": [120, 180, 450, 900],
      "name": "Compute Time per 100 Steps (s)",
      "type": "scatter",
      "mode": "lines+markers",
      "yaxis": "y2",
      "line": {"color": "#fa5252"}
    }
  ],
  "layout": {
    "title": "Checkpoint Overhead vs. Model Size (HDD Storage)",
    "xaxis": {"title": "Model Parameter Size"},
    "yaxis": {"title": "Save Time (seconds)", "side": "left"},
    "yaxis2": {"title": "Compute Time per 100 Steps (seconds)", "overlaying": "y", "side": "right"},
    "legend": {"x": 0.1, "y": 1.1, "orientation": "h"},
    "margin": {"l": 50, "r": 50, "t": 50, "b": 50}
  }
}
```

As model size grows, I/O latency (blue bars) becomes significant. While compute time (red line) can be kept in check with efficient parallelization, storage throughput often hits a hard ceiling, so each save consumes a growing share of wall-clock time.

To minimize this blocking time, use asynchronous checkpointing. PyTorch provides `async_save` within the DCP module, which hands the I/O off to a background thread and lets the GPU return to the training loop immediately.

```python
# Enable async save to hide I/O latency.
# Note: requires careful memory management, as CPU memory usage
# increases temporarily while the write is in flight.
dcp.async_save(
    state_dict=state_payload,
    storage_writer=dcp.FileSystemWriter(ckpt_dir, thread_count=4),
)
```

When implementing async saving, monitor CPU RAM usage. The state dictionary remains pinned in host memory until the write completes. On nodes with limited CPU RAM, overlapping the copy-to-host and write-to-disk phases can trigger OOM errors if the previous checkpoint has not finished writing before the next one begins.
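A simple guard against that failure mode, sketched below on the assumption that `dcp.async_save` returns a future that completes once the data has been flushed (as in recent PyTorch releases), is to wait on the previous checkpoint before staging the next one. The `maybe_async_checkpoint` helper and its per-step directory layout are illustrative, not part of the DCP API.

```python
import torch.distributed.checkpoint as dcp

# Handle for the checkpoint currently being written in the background, if any.
pending_checkpoint = None


def maybe_async_checkpoint(state_payload, ckpt_dir, step, interval=500):
    """Launch an async checkpoint every `interval` steps, waiting for the
    previous write first so host memory is never held for two snapshots."""
    global pending_checkpoint
    if step == 0 or step % interval != 0:
        return
    if pending_checkpoint is not None:
        # Block until the prior checkpoint has been fully written to storage.
        pending_checkpoint.result()
    pending_checkpoint = dcp.async_save(
        state_dict=state_payload,
        storage_writer=dcp.FileSystemWriter(f"{ckpt_dir}/step_{step}"),
    )
```

If that wait regularly takes noticeable time, the checkpoint interval is too short for the available storage bandwidth, and the synchronous cost has merely been deferred rather than hidden.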