趋近智
大型语言模型的训练通常持续数周,因此硬件可靠性是一个主要考量。当使用数十或数百个GPU进行训练时,将模型的完整状态字典汇集到单个进程进行序列化的传统做法效率不高,并且由于主机内存限制常常无法实现。如果模型大小M超出协调器节点的CPU内存,简单的torch.save()调用将引发内存不足(OOM)错误。
本章侧重于专为分片架构设计的持久化方案。我们考察当参数在集群中被分区时如何管理状态字典。你将学会实现PyTorch的分布式检查点(DCP)API,它支持并行I/O操作,每个进程只保存其本地的模型分片。
Total I/O Bandwidth≈N×Per-Rank Bandwidth
通过使用所有可用的存储控制器,我们减少了在输入/输出阻塞状态中花费的时间。此外,我们通过TorchElastic处理容错问题。你将配置训练脚本,使其能够检测工作进程故障、重新建立进程组,并无需手动干预地从最后一个有效的快照自动恢复。
5.1 分片状态字典与完整状态字典
5.2 PyTorch 分布式检查点 API
5.3 弹性训练集成
5.4 实践:实现可恢复训练
© 2026 ApX Machine Learning用心打造