趋近智
在分布式环境中管理状态字典本质上是一个资源管理难题。当训练包含数十亿参数的模型时,传统的 PyTorch 检查点工作流程从一个例行操作转变为主要的系统瓶颈。将模型权重整合到单个设备上进行序列化的标准做法,在模型大小超出任何单个硬件单元的内存容量时就会失效。
在非分布式设置中,model.state_dict() 返回一个将参数名映射到张量的字典。在完全分片数据并行(FSDP)中,物理参数在不同的进程(rank)之间进行划分。访问完整状态字典需要通过网络收集这些分片以重建全局张量。此操作会引入两种不同的故障模式:AllGather 阶段的网络饱和,以及负责序列化的协调进程(coordinator rank)上的主机内存耗尽。
当你在 FSDP 模型上调用标准的状态字典检索时,系统会默认聚合所有参数。如果你正在训练一个包含 700 亿参数的混合精度模型(FP32 主权重),仅模型本身就占用约 280 GB 内存。
要使用传统方法保存此模型,进程 0 必须为整个 280 GB 结构以及序列化缓冲区分配内存。GPU 集群中的大多数服务器级 CPU 配备 512 GB 到 1 TB 的 RAM,但这些内存由所有本地进程共享。如果一个节点承载 8 个 GPU,并且每个进程都需要一个用于聚合的 CPU 缓冲区,主机内存将迅速超额使用。
此操作的计算开销与模型大小呈线性关系,但与收集阶段的网络带宽可用性成反比。使用单个协调器对大小为 S 的模型进行检查点操作所需的时间 Tsave 为:
Tsave≈BnetS+BdiskS
其中 Bnet 是有效的互连带宽,而 Bdisk 是存储控制器的写入速度。此公式忽略了同步屏障引入的显著延迟,在该屏障中,所有 GPU 工作进程必须暂停计算并等待收集过程完成。
为解决这些瓶颈,FSDP 提供了替代的序列化策略,这些策略适应数据的分布式特性。与其重建全局张量,我们可以直接保存本地分片。这种方法从“先收集后写入”转变为“并行写入”。
PyTorch 通过 StateDictType 枚举公开了这些策略。理解这些类型之间的区别对于实现有效的容错功能是必需的。
FULL_STATE_DICT)这是前面描述的默认行为。它重建未分片的模型。虽然在训练期间计算开销大且内存占用高,但它创建一个可移植的检查点。你可以将完整状态字典加载到单个 CPU 上的模型或不同集群拓扑结构中,而无需复杂的转换逻辑。通常建议仅在训练结束时执行一次此聚合以进行推理导出,而不是在中间检查点期间。
SHARDED_STATE_DICT)此策略保存参数时,按照它们在 FSDP 包装器中逻辑存在的方式,但将它们按进程(rank)分片。每个 GPU 仅保存其当前拥有的数据分片。这会产生 N 个更小的文件(或并行写入共享对象存储),而不是一个单一的巨大文件。
优点显而易见:
吞吐量分片≈min(N×B磁盘,B存储后端)
此方法保持参数的逻辑映射,允许将检查点重新加载到具有不同数量 GPU 的集群中(重新分配),前提是底层拓扑能够支持分片的重新分配。
LOCAL_STATE_DICT)这表示 GPU 内存中原始、扁平化的存储,与其在内存中的状态完全一致,通常包含 FSDP 为对齐而添加的内部填充。这是最快的方法,因为它不执行任何处理或元数据管理。然而,它与特定的集群拓扑结构紧密关联。使用 LOCAL_STATE_DICT 在 32 个 GPU 上保存的检查点,除非进行大量的手动张量操作,否则无法轻易加载到 64 个 GPU 上。这很少用于生产检查点,但可用于临时调试快照。
下图对比了完整状态聚合与分片持久化之间的数据流。请注意聚合场景中进程 0 上的瓶颈,与分片方法的分布式吞吐量相比。
检查点拓扑结构比较。传统方法在协调进程上造成通信和内存瓶颈。分片方法将 I/O 负载分布到所有参与进程,实现写入吞吐量的线性扩展。
为强制使用特定的状态字典类型,PyTorch 提供了 FSDP.state_dict_type 上下文管理器。此上下文必须包裹 state_dict() 调用,以改变参数的收集方式(或不收集)。
使用 SHARDED_STATE_DICT 时,返回的字典包含 ShardedTensor 对象,而不是标准的 torch.Tensor 对象。这些在序列化期间需要特殊处理。分布式检查点 API (torch.distributed.checkpoint) 被设计为原生处理这些分片张量。
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
def save_checkpoint(model, rank, checkpoint_path):
# 配置模型以返回参数的分片视图
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
# 获取包含 ShardedTensors 的状态字典
state_dict = model.state_dict()
# 使用分布式检查点 API 进行并行 I/O
# dcp.save 处理来自多个进程写入分片的复杂性
dcp.save(
state_dict=state_dict,
checkpoint_id=checkpoint_path,
)
# 加载需要相同的上下文来正确映射分片
def load_checkpoint(model, checkpoint_path):
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
state_dict = model.state_dict()
# 就地加载到预分片模型中
dcp.load(
state_dict=state_dict,
checkpoint_id=checkpoint_path,
)
model.load_state_dict(state_dict)
此实现确保没有单个张量超出分配给 GPU 或主机 CPU 进程的内存。dcp 模块处理底层的异步 I/O,使训练循环能够尽快恢复计算。
转向分片检查点显著减少了阻塞 I/O 状态中花费的时间。在对超过 100 亿参数的模型进行的实验中,差异呈指数级增长。下图说明了随着模型大小增加,保存检查点所需的时间,比较了传统聚合方法与分片保存方法。
检查点操作的延迟比较。完整状态方法扩展性不佳,最终导致超时或 OOM 错误(此处模拟为指数增长)。分片状态方法保持与模型大小接近线性的性能,利用并行带宽。
通过采用 SHARDED_STATE_DICT 和分布式检查点 API,你可以将检查点开销与单节点限制解耦。这是训练太字节规模模型的强制性架构模式,确保系统保持弹性且高效,即使参数数量增长到数千亿。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造