FullyShardedDataParallel (FSDP) 与 DistributedDataParallel (DDP) 相比,为分布式训练提供了一种不同做法,不只简单的类替换。虽然它们的API表面看起来相似,但参数管理的核心机制有显著区别。在DDP中,模型包装器主要在桶级别处理梯度同步(all-reduce)。在FSDP中,包装器拥有参数,并在进程组中实际拆分张量。本节说明FSDP的编程实现,着重于ShardingStrategy配置和MixedPrecision策略,它们是在大规模训练中保持稳定所必需的。FSDP构造器接口分片操作的入口是FullyShardedDataParallel类。不同于DDP接受一个已在目标GPU上的模型,FSDP通常包装位于CPU上的模型,以避免初始化时立即出现内存不足(OOM)错误。构造器签名提供了对理论分析中提到的ZeRO阶段的控制。最重要的参数是sharding_strategy,它决定了模型状态如何划分。from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, BackwardPrefetch, ) # 基础包装结构 model = FSDP( module, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=mixed_precision_policy, device_id=torch.cuda.current_device() )ZeRO阶段到分片策略的映射PyTorch将ZeRO优化的数学定义映射到ShardingStrategy枚举。选择正确的策略在于内存限制和通信带宽之间的平衡。FULL_SHARD (ZeRO 阶段 3): 这是默认行为。参数、梯度和优化器状态都被分片。参数只在计算特定层的正向和反向传播期间在GPU上具体化(聚集),然后立即释放。这提供了最大的内存节省($$1/N$$ 缩放),但由于频繁的AllGather操作而产生高通信开销。SHARD_GRAD_OP (ZeRO 阶段 2): 梯度和优化器状态被分片,但参数保持复制状态。此策略避免了正向传播期间聚集权重的通信开销,但需要足够的显存来保存完整的模型参数($$\Psi$$)。NO_SHARD (DDP 等效): 此模式在FSDP API中复制DDP的行为。它有助于调试,或者在缩减到少量GPU时使用,因为那时分片开销超过了其优势。以下图表说明了这些策略在双GPU设置中的内存分配区别。digraph G { rankdir=TB; node [shape=rect, style=filled, fontname="Arial", fontsize=10]; splines=ortho; bgcolor="transparent"; subgraph cluster_0 { label="标准DDP / NO_SHARD"; style=dashed; color="#adb5bd"; subgraph cluster_gpu0_ddp { label="GPU 0"; color="#dee2e6"; style=filled; node [width=1.5]; P0 [label="参数 (完整)", fillcolor="#a5d8ff"]; G0 [label="梯度 (完整)", fillcolor="#ffc9c9"]; O0 [label="优化器 (完整)", fillcolor="#b2f2bb"]; } subgraph cluster_gpu1_ddp { label="GPU 1"; color="#dee2e6"; style=filled; node [width=1.5]; P1 [label="参数 (完整)", fillcolor="#a5d8ff"]; G1 [label="梯度 (完整)", fillcolor="#ffc9c9"]; O1 [label="优化器 (完整)", fillcolor="#b2f2bb"]; } } subgraph cluster_1 { label="FSDP FULL_SHARD (ZeRO-3)"; style=dashed; color="#adb5bd"; subgraph cluster_gpu0_fsdp { label="GPU 0"; color="#dee2e6"; style=filled; node [width=1.5]; P0_s [label="参数 [分片 0]", fillcolor="#4dabf7"]; G0_s [label="梯度 [分片 0]", fillcolor="#ff8787"]; O0_s [label="优化器 [分片 0]", fillcolor="#69db7c"]; } subgraph cluster_gpu1_fsdp { label="GPU 1"; color="#dee2e6"; style=filled; node [width=1.5]; P1_s [label="参数 [分片 1]", fillcolor="#4dabf7"]; G1_s [label="梯度 [分片 1]", fillcolor="#ff8787"]; O1_s [label="优化器 [分片 1]", fillcolor="#69db7c"]; } } }标准复制和完全分片之间的内存分配比较。在FULL_SHARD中,更新的存储和计算都已分布,设备占用量随大小线性减少。混合精度配置不同于torch.cuda.amp通常需要外部缩放器和autocast上下文管理器,FSDP直接将混合精度整合到分片逻辑中。这种整合是必要的,原因在于FSDP在跨等级通信张量之前必须知道精度格式。如果通信采用FP32而计算采用BF16,则会浪费带宽。MixedPrecision配置类控制三种特定数据类型:param_dtype: 正向和反向计算期间用于模型参数的类型。reduce_dtype: 用于梯度归约(通信)的类型。buffer_dtype: 缓冲区(例如,批量归一化统计数据)的类型。对于Ampere或Hopper架构上的现代大型语言模型(LLM)训练,bfloat16是标准。它保留了FP32的指数范围,避免了FP16常见的下溢问题,通常无需损失缩放。import torch from torch.distributed.fsdp import MixedPrecision # 定义BFloat16训练策略 bf16_policy = MixedPrecision( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, # 将通信量减少50% buffer_dtype=torch.bfloat16, )当param_dtype设置为bfloat16时,FSDP会保留FP32主权重(如果存在),但在正向传播之前将其转换为BFloat16。这与混合精度训练中的“主权重”思想一致,确保优化器步骤中的小权重更新不会因精度截断而丢失。扁平参数与正向传播当你用FSDP包装一个模块时,原始参数(例如model.layer1.weight)会被一个通用FlatParameter替换。这是一个单一的1D张量,将多个原始参数的存储视为一个整体。这种扁平化提高了内存访问模式和通信效率。FSDP没有为单个权重矩阵启动数百个小型NCCL内核,而是将它们聚合为更大的块。然而,这带来了一个限制:包装后你不能直接访问model.layer1.weight,因为该属性实际上不再以原始形式存在于设备上。访问它需要使用上下文管理器FSDP.summon_full_params(model),我们将在检查点部分进行说明。完整实现示例以下示例展示了一种初始化模式。它建立进程组,定义模型,并使用FULL_SHARD和BFloat16精度将其包装。import os import torch import torch.nn as nn import torch.distributed as dist from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, StateDictType, ) def setup(rank, world_size): os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' # 使用NCCL后端初始化GPU的进程组 dist.init_process_group("nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) def cleanup(): dist.destroy_process_group() def train_fsdp_step(rank, world_size): setup(rank, world_size) # 1. 定义模型 (标准PyTorch) # 在实际场景中,这很可能是一个Transformer model = nn.Sequential( nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024) ) # 2. 定义策略 bf16_policy = MixedPrecision( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16, ) # 3. 用FSDP包装 # 注意: 我们通过device_id参数隐式地移动到设备。 # 对于大型模型,首先加载到'meta'设备或CPU。 fsdp_model = FSDP( model, sharding_strategy=ShardingStrategy.FULL_SHARD, mixed_precision=bf16_policy, device_id=torch.cuda.current_device(), ) # 4. 优化器初始化 # 重要: 优化器必须在包装后初始化 # 因为FSDP会修改参数结构。 optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3) # 5. 训练循环 # 在正确的设备上生成模拟数据 inputs = torch.randn(64, 1024).to(rank).bfloat16() optimizer.zero_grad() output = fsdp_model(inputs) loss = output.sum() loss.backward() optimizer.step() print(f"Rank {rank} step complete. Loss: {loss.item()}") cleanup() # 这个函数将在脚本中通过torch.multiprocessing启动 # mp.spawn(train_fsdp_step, args=(world_size,), nprocs=world_size)设备初始化与device_id在上面的示例中,device_id参数非常要紧。如果省略,FSDP可能会尝试在CPU或默认GPU上初始化分片,可能导致无声的性能下降或放置错误。当模型能适应单个GPU(如示例所示)时,包装一个已具体化的模型是可以接受的。然而,对于接近TB级别的模型,在分片前将完整模型具体化到CPU内存是不可能的。在这些情况下,我们采用meta设备进行延迟初始化,允许FSDP分片参数,而无需在任何单个设备上分配完整的模型状态。这种高级初始化模式是下一章的内容。