Training runs for large language models often span weeks, which makes hardware reliability a significant variable. When training across dozens or hundreds of GPUs, the conventional approach of gathering the model's full state dictionary onto a single rank for serialization is slow and often impossible because of host memory limits. If the size of the model state, M, exceeds the CPU memory of the coordinator node, gathering that state for a simple torch.save() call fails with an Out of Memory (OOM) error.
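A back-of-the-envelope sketch makes the memory constraint concrete. It assumes fp32 weights plus Adam's two fp32 moment buffers (12 bytes per parameter), a hypothetical coordinator with 512 GiB of host RAM, and an illustrative 70B-parameter model sharded over 64 ranks. Real checkpoint sizes vary with precision and optimizer choice.

```python
# Back-of-the-envelope check: can one host hold the full gathered checkpoint?
# Assumption: fp32 weights (4 B) + Adam exp_avg (4 B) + exp_avg_sq (4 B) = 12 B/param.
BYTES_PER_PARAM = 4 + 4 + 4
GIB = 1024 ** 3


def full_checkpoint_bytes(num_params: int, bytes_per_param: int = BYTES_PER_PARAM) -> int:
    """Size of the full (unsharded) model + optimizer state."""
    return num_params * bytes_per_param


def per_rank_shard_bytes(num_params: int, world_size: int,
                         bytes_per_param: int = BYTES_PER_PARAM) -> int:
    """Bytes each rank writes when the state is evenly sharded (ceiling division)."""
    total = full_checkpoint_bytes(num_params, bytes_per_param)
    return -(-total // world_size)


def fits_on_coordinator(num_params: int, host_mem_bytes: int,
                        bytes_per_param: int = BYTES_PER_PARAM) -> bool:
    """Gathering to one rank only works if the full state fits in that host's RAM."""
    return full_checkpoint_bytes(num_params, bytes_per_param) <= host_mem_bytes


if __name__ == "__main__":
    params = 70_000_000_000   # illustrative 70B-parameter model
    host_mem = 512 * GIB      # hypothetical coordinator host RAM
    world_size = 64           # hypothetical number of ranks

    total = full_checkpoint_bytes(params)
    shard = per_rank_shard_bytes(params, world_size)
    print(f"full checkpoint: {total / GIB:.1f} GiB")
    print(f"per-rank shard:  {shard / GIB:.1f} GiB")
    print(f"fits on coordinator: {fits_on_coordinator(params, host_mem)}")
```

Under these assumptions the full state is several hundred GiB, while each rank's shard is only a small fraction of it, which is why per-rank saving sidesteps the coordinator bottleneck.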
This chapter focuses on persistence strategies for sharded architectures. We examine how to manage state dictionaries when parameters are partitioned across the cluster. You will learn to use PyTorch's Distributed Checkpointing (DCP) API, torch.distributed.checkpoint, which enables parallel I/O in which each rank saves only its local shard of the model.
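The core idea of per-rank saving can be modeled without any GPUs. The sketch below is a pure-Python toy, not DCP itself: the real API (torch.distributed.checkpoint.save and torch.distributed.checkpoint.load) handles tensors, save planning, and metadata for you. Here, simulated ranks write their slice of a flat parameter list concurrently, and a coordinator records the layout afterward. The file names, the JSON format, and the helper functions are all hypothetical.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def shard_bounds(numel: int, rank: int, world_size: int) -> tuple[int, int]:
    """Contiguous [start, end) slice owned by `rank`; early ranks absorb the remainder."""
    base, extra = divmod(numel, world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return start, end


def save_local_shard(ckpt_dir: str, rank: int, shard: list[float]) -> str:
    """Each rank writes ONLY its own slice to its own file."""
    path = os.path.join(ckpt_dir, f"shard_rank{rank}.json")
    with open(path, "w") as f:
        json.dump(shard, f)
    return path


def sharded_save(ckpt_dir: str, full_param: list[float], world_size: int) -> None:
    # Simulation only: in a real job each rank already holds just its slice,
    # so no process would ever see `full_param` as a whole.
    shards = {}
    for r in range(world_size):
        s, e = shard_bounds(len(full_param), r, world_size)
        shards[r] = full_param[s:e]
    # Simulated ranks write concurrently, mirroring parallel per-rank I/O.
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        list(pool.map(lambda r: save_local_shard(ckpt_dir, r, shards[r]), range(world_size)))
    # A coordinator records the layout once every shard is on disk.
    meta = {"numel": len(full_param), "world_size": world_size}
    with open(os.path.join(ckpt_dir, "metadata.json"), "w") as f:
        json.dump(meta, f)


def load_local_shard(ckpt_dir: str, rank: int) -> list[float]:
    """On restore, each rank reads back only the file it wrote."""
    with open(os.path.join(ckpt_dir, f"shard_rank{rank}.json")) as f:
        return json.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        param = [float(i) for i in range(10)]
        sharded_save(d, param, world_size=4)
        print(sorted(os.listdir(d)))
        print(load_local_shard(d, 1))
```

No single process ever serializes the whole parameter list, so peak memory per writer scales with the shard size rather than with M.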
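With N ranks writing their shards concurrently, aggregate write throughput scales roughly linearly, as long as the storage backend is not saturated:

$$\text{Total I/O Bandwidth} \approx N \times \text{Per-Rank Bandwidth}$$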
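The approximation can be checked with a few lines of arithmetic. This sketch assumes a hypothetical 2 GB/s of write bandwidth per rank and an illustrative 840 GB checkpoint. The optional `backend_cap_gb_s` parameter is an added assumption that models a shared filesystem saturating, which is why the relationship is only approximate.

```python
from typing import Optional


def aggregate_write_bandwidth(world_size: int, per_rank_gb_s: float,
                              backend_cap_gb_s: Optional[float] = None) -> float:
    """Total ~= N x per-rank bandwidth, optionally clipped by a shared-storage ceiling."""
    total = world_size * per_rank_gb_s
    if backend_cap_gb_s is not None:
        total = min(total, backend_cap_gb_s)
    return total


def checkpoint_write_seconds(checkpoint_gb: float, world_size: int, per_rank_gb_s: float,
                             backend_cap_gb_s: Optional[float] = None) -> float:
    """Wall-clock time to write the whole checkpoint at the aggregate rate."""
    return checkpoint_gb / aggregate_write_bandwidth(world_size, per_rank_gb_s, backend_cap_gb_s)


if __name__ == "__main__":
    ckpt_gb = 840.0  # illustrative checkpoint size
    print(f"1 writer:            {checkpoint_write_seconds(ckpt_gb, 1, 2.0):.1f} s")
    print(f"64 writers:          {checkpoint_write_seconds(ckpt_gb, 64, 2.0):.1f} s")
    print(f"64 writers, capped:  {checkpoint_write_seconds(ckpt_gb, 64, 2.0, 50.0):.1f} s")
```

The single-writer figure is optimistic, since it ignores the cost of gathering the state to that rank in the first place.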
By spreading writes across all available storage controllers, we reduce the time training stalls on blocking I/O.

We also address fault tolerance with TorchElastic, which is launched through torchrun. You will configure training scripts that detect worker failures, re-establish process groups, and resume automatically from the last valid snapshot without manual intervention.
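TorchElastic performs the failure detection, rendezvous, and process restart itself. What the training script must provide is the ability to find and load the last valid snapshot at startup. The pure-Python toy below models only that control flow: checkpoints are committed atomically (write to a temp file, then `os.replace`) so a crash mid-write never produces a "valid-looking" partial file, and a hypothetical `elastic_supervisor` stands in for the elastic agent by rerunning `train` after a simulated failure. All names and the `step_XXXXXXXX.json` layout are illustrative assumptions.

```python
import json
import os
import tempfile
from typing import Optional


class SimulatedWorkerFailure(RuntimeError):
    """Hypothetical stand-in for a crashed rank (e.g. a GPU or node fault)."""


def save_checkpoint_atomic(ckpt_dir: str, step: int, state: dict) -> None:
    # Write to a temp file first, then rename, so a crash mid-write never
    # leaves a half-written file under the final name.
    final_path = os.path.join(ckpt_dir, f"step_{step:08d}.json")
    tmp_path = final_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"step": step, "state": state}, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, final_path)  # atomic rename on POSIX filesystems


def latest_valid_checkpoint(ckpt_dir: str) -> Optional[dict]:
    """Newest committed checkpoint; leftover *.tmp files are ignored."""
    done = sorted(n for n in os.listdir(ckpt_dir)
                  if n.startswith("step_") and n.endswith(".json"))
    if not done:
        return None
    with open(os.path.join(ckpt_dir, done[-1])) as f:
        return json.load(f)


def train(ckpt_dir: str, total_steps: int, save_every: int,
          fail_at_step: Optional[int] = None) -> dict:
    # Every (re)start begins by resuming from the last valid snapshot.
    ckpt = latest_valid_checkpoint(ckpt_dir)
    step = ckpt["step"] if ckpt else 0
    state = ckpt["state"] if ckpt else {"loss_sum": 0.0}
    while step < total_steps:
        step += 1
        if step == fail_at_step:
            raise SimulatedWorkerFailure(f"worker died at step {step}")
        state["loss_sum"] += 1.0 / step  # placeholder for a real training step
        if step % save_every == 0:
            save_checkpoint_atomic(ckpt_dir, step, state)
    return {"step": step, "state": state}


def elastic_supervisor(ckpt_dir: str, total_steps: int, save_every: int,
                       fail_once_at: Optional[int], max_restarts: int = 3):
    """Toy analogue of an elastic agent: rerun the worker on failure."""
    fail_at = fail_once_at
    for attempt in range(max_restarts + 1):
        try:
            return train(ckpt_dir, total_steps, save_every, fail_at), attempt
        except SimulatedWorkerFailure:
            fail_at = None  # assume the fault was transient
    raise RuntimeError("exceeded max_restarts")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        result, restarts = elastic_supervisor(d, total_steps=20, save_every=5, fail_once_at=13)
        print(f"finished at step {result['step']} after {restarts} restart(s)")
```

Under a real torchrun launch, the number of restarts is bounded with the `--max-restarts` flag, and the relaunched script must perform the same "load latest, then continue" step that `train` does here. Because the resumed run replays steps 11 to 20 from the step-10 snapshot, its final state matches an uninterrupted run.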
5.1 Sharded vs Full State Dictionaries
5.2 PyTorch Distributed Checkpointing API
5.3 Elastic Training Integration
5.4 Practice: Implementing Resumable Training
© 2026 ApX Machine Learning