At the scale of terabyte-sized models, hardware reliability transitions from a background concern to a primary engineering constraint. As the number of GPUs $N$ increases, the probability of a training run completing without interruption decreases exponentially. If a single GPU has a probability $p$ of failing within a given timeframe, the probability of the entire cluster remaining stable over that timeframe is $(1-p)^N$. In a cluster where $N = 1024$ and $p = 0.001$ (0.1%), the chance of getting through that timeframe without a failure is approximately 36%.
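To make the arithmetic concrete, the short sketch below evaluates this expression for the cluster described above; it reproduces the roughly 36% survival figure.

# Probability that a cluster of N GPUs survives a window in which each GPU
# fails independently with probability p (numbers from the example above).
N = 1024      # number of GPUs in the cluster
p = 0.001     # per-GPU failure probability within the window (0.1%)

survival = (1 - p) ** N
print(f"P(no failure across {N} GPUs) = {survival:.3f}")  # ~0.359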
Training jobs that crash completely upon the loss of a single rank are unsustainable for large language models. PyTorch addresses this through TorchElastic, a component now integrated into the core library that manages worker lifecycles. While Distributed Data Parallel (DDP) or FSDP handles the gradient synchronization, TorchElastic handles the process orchestration. It provides the capability to detect worker failures, pause the remaining healthy workers, reorganize the process group, and respawn the failed processes to resume training.
Standard distributed training relies on a static definition: every rank knows exactly how many peers exist and their addresses at initialization. If rank 5 fails, rank 0 waits indefinitely for a signal that will never arrive, and the job hangs until the collective operation times out.
TorchElastic introduces an indirection layer between the cluster manager (such as Slurm or Kubernetes) and the PyTorch training script. This layer consists of an Elastic Agent running on each node. These agents coordinate through a Rendezvous backend to establish the worker group.
The flow of operations during a failure event follows a specific state transition:
1. The Elastic Agent on the affected node detects that a worker process has died and marks the current rendezvous round for this run_id as invalid.
2. All agents stop their remaining healthy workers and re-enter the rendezvous to negotiate new world_size and rank assignments.
3. Each agent respawns its workers with freshly populated environment variables (RANK, WORLD_SIZE, MASTER_ADDR).

This architecture demands that the training script be idempotent regarding initialization. Since the script effectively runs from line zero again after a failure, it must be able to detect existing checkpoints and resume rather than overwriting them.
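Because these environment variables are refreshed on every restart, the training script should read its identity from the environment each time it launches rather than caching it. A minimal sketch of that pattern (the helper function itself is illustrative, not part of the PyTorch API):

import os

import torch
import torch.distributed as dist


def init_distributed():
    # The Elastic Agent exports these on every (re)start; never hard-code them.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)
    # MASTER_ADDR and MASTER_PORT are also provided by the agent, so the
    # default env:// initialization picks them up automatically.
    dist.init_process_group(backend="nccl")
    return rank, local_rank, world_size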
The interaction between local Elastic Agents and the global Rendezvous backend during a failure event. The agent on Node 2 reports the failure, prompting the Rendezvous system to instruct Node 1 to tear down and prepare for re-initialization.
The entry point for elastic training is torchrun (formerly python -m torch.distributed.launch). This CLI tool sets up the environment variables required for FSDP to initialize the process group correctly.
In a non-elastic setup, you might define MASTER_ADDR and MASTER_PORT manually. With torchrun, you rely on the Rendezvous backend. For high-performance clusters, the c10d backend is preferred over etcd as it runs directly on the training nodes without external dependencies.
A typical command for a multi-node FSDP job looks like this:
torchrun \
--nnodes=4 \
--nproc_per_node=8 \
--rdzv_id=job_101 \
--rdzv_backend=c10d \
--rdzv_endpoint=node-01.internal:29500 \
train_fsdp.py
The rdzv_id acts as a unique session identifier. If a node fails and restarts, it must use the same rdzv_id to rejoin the ongoing training cluster. The nnodes argument can also specify a range (e.g., 3:4), allowing the job to continue even if one node is permanently lost, assuming the batch size and gradient accumulation steps are adjusted dynamically.
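One practical consequence of shrinking the cluster is that the effective global batch size changes. A small, hypothetical helper such as the one below (its name and parameters are illustrative, not part of any PyTorch API) can recompute the gradient accumulation steps from the surviving world size so the optimizer still sees roughly the same global batch:

import math

import torch.distributed as dist


def accumulation_steps(target_global_batch: int, per_gpu_batch: int) -> int:
    # Keep the effective global batch size roughly constant when the elastic
    # job resumes with fewer GPUs (e.g., 24 instead of 32).
    world_size = dist.get_world_size()
    samples_per_step = world_size * per_gpu_batch
    # Round up so the effective batch never falls below the target.
    return max(1, math.ceil(target_global_batch / samples_per_step))

# Example: a target global batch of 1024 samples at 4 samples per GPU needs
# 8 accumulation steps on 32 GPUs, but 11 on 24 GPUs.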
To support the restart mechanics described above, your training code requires specific structural patterns. FSDP does not automatically persist state; you must implement the save/load logic.
When a failure occurs, torchrun terminates all processes and restarts the script from the beginning. Therefore, the script initialization phase must check for the existence of a checkpoint.
We use the term Snapshot to refer to the comprehensive state required to resume training, which includes the model weights, optimizer state, scheduler state, and current epoch/step counters.
Using the Distributed Checkpointing (DCP) API discussed in the previous section is key here. Standard torch.save usually requires gathering all weights to rank 0, which causes memory spikes that can crash the very recovery process you are trying to implement. DCP saves sharded states, allowing each rank to write in parallel.
Here is the logic flow required in the main function:
import os

import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict


def load_snapshot(model, optimizer, path):
    # No snapshot on disk yet: start from epoch 0
    if not os.path.exists(path):
        return 0

    # The model and optimizer must already be initialized (and FSDP-sharded).
    # get_state_dict returns sharded state dicts matching the runtime layout.
    model_state, optim_state = get_state_dict(model, optimizer)
    state_dict = {
        "model": model_state,
        "optimizer": optim_state,
    }

    # DCP handles the mapping of sharded weights on disk to the current
    # sharding strategy in memory; every rank reads its slice in parallel.
    dcp.load(state_dict=state_dict, checkpoint_id=path)
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optimizer"],
    )

    # Retrieve metadata (step/epoch) separately or include it in state_dict
    # ... implementation details ...
    loaded_epoch = 0  # placeholder until the metadata loading is wired in

    print(f"Resuming from snapshot: {path}")
    return loaded_epoch
def train(model, optimizer):
    # Initialize the process group for FSDP (NCCL for GPU clusters)
    dist.init_process_group(backend="nccl")

    # FSDP wrapping and initialization
    # ...

    # Attempt to load a snapshot; returns 0 if none exists yet
    start_epoch = load_snapshot(model, optimizer, "checkpoints/latest")

    for epoch in range(start_epoch, TOTAL_EPOCHS):
        # Training loop
        # ...

        # Save a snapshot at the end of the epoch (or every N steps).
        # DCP saves are collective: every rank writes its own shard, so the
        # call runs on all ranks rather than being gated on rank 0.
        save_snapshot(model, optimizer, "checkpoints/latest")
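The loop above calls a save_snapshot helper that has not been shown yet. A minimal sketch of it, mirroring load_snapshot with the same DCP API (the helper name and the metadata handling are illustrative):

def save_snapshot(model, optimizer, path):
    # Collect sharded state dicts that reflect the current FSDP layout
    model_state, optim_state = get_state_dict(model, optimizer)
    state_dict = {
        "model": model_state,
        "optimizer": optim_state,
    }

    # Every rank writes its own shards in parallel under `path`
    dcp.save(state_dict=state_dict, checkpoint_id=path)

    # Step/epoch counters can be written as a small sidecar file by rank 0
    # or folded into state_dict; that detail is omitted here.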
A complex edge case in elastic training arises when the cluster size changes. Suppose you start training on 4 nodes (32 GPUs) and one node suffers a catastrophic failure. You might decide to resume training on just 3 nodes (24 GPUs) rather than waiting for hardware replacement.
In a standard FSDP setup, the model parameters are sharded across the world size. Rank 0 in a 32-GPU setup holds 1/32 of the parameters; in a 24-GPU setup, Rank 0 must hold 1/24.
If you used torch.save(model.state_dict()) (which saves unsharded full weights), resumption is straightforward but memory-inefficient. If you saved sharded checkpoints (e.g., ShardedStateDict), the number of shards on disk corresponds to the previous world size.
The torch.distributed.checkpoint (DCP) module solves this by decoupling the stored data structure from the runtime sharding strategy. When loading a DCP checkpoint, each rank requests only the slices of each tensor it owns under the current sharding plan, and DCP reshards the stored data to fit, even if the world_size has changed.

This capability transforms FSDP from a rigid parallelization scheme into a flexible distributed system capable of adapting to volatile infrastructure.
Determining how often to checkpoint involves a trade-off between I/O overhead and the compute wasted upon failure. We can model the cost of a failure, $C_{\text{fail}}$, as:

$$C_{\text{fail}} = T_{\text{restart}} + T_{\text{recompute}}$$

where $T_{\text{restart}}$ is the time to reload the model and $T_{\text{recompute}}$ is the time lost since the last checkpoint. To minimize the expected wasted time, the optimal checkpoint interval $\tau$ can be approximated using Young's approximation:

$$\tau \approx \sqrt{2 \times \delta \times \text{MTBF}}$$

where $\delta$ is the time taken to write a checkpoint and MTBF is the mean time between failures of the cluster. Since FSDP with DCP allows parallel writing, $\delta$ is significantly lower than with rank-0-only serialization. This permits more frequent checkpoints (e.g., every 30 minutes instead of every 4 hours), drastically reducing the compute wasted during the inevitable hardware failures of long-running jobs.
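As a worked example under assumed numbers (a 2-minute checkpoint write and a 4-hour cluster MTBF, neither taken from measurements in this section), the optimal interval lands close to that 30-minute cadence:

import math

delta_minutes = 2        # assumed time to write one DCP checkpoint
mtbf_minutes = 4 * 60    # assumed mean time between failures for the cluster

tau = math.sqrt(2 * delta_minutes * mtbf_minutes)
print(f"Optimal checkpoint interval ~ {tau:.0f} minutes")  # ~31 minutes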
The following chart visualizes the impact of parallel distributed checkpointing on the I/O overhead, enabling this higher frequency.
Comparison of checkpoint write latencies. As model size grows, aggregating to a single rank becomes a bottleneck, whereas DCP uses the aggregate bandwidth of the entire cluster.
By integrating torchrun with proper snapshot logic and the DCP API, you ensure that your training run is resilient. This resilience is not merely a convenience; for models requiring months of GPU time, it is the only way to guarantee convergence in an imperfect physical environment.