趋近智
FullyShardedDataParallel (FSDP) 与 DistributedDataParallel (DDP) 相比,为分布式训练提供了一种不同做法,不只简单的类替换。虽然它们的API表面看起来相似,但参数管理的核心机制有显著区别。在DDP中,模型包装器主要在桶级别处理梯度同步(all-reduce)。在FSDP中,包装器拥有参数,并在进程组中实际拆分张量。
本节说明FSDP的编程实现,着重于ShardingStrategy配置和MixedPrecision策略,它们是在大规模训练中保持稳定所必需的。
分片操作的入口是FullyShardedDataParallel类。不同于DDP接受一个已在目标GPU上的模型,FSDP通常包装位于CPU上的模型,以避免初始化时立即出现内存不足(OOM)错误。
构造器签名提供了对理论分析中提到的ZeRO阶段的控制。最重要的参数是sharding_strategy,它决定了模型状态如何划分。
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
BackwardPrefetch,
)
# 基础包装结构
model = FSDP(
module,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision_policy,
device_id=torch.cuda.current_device()
)
PyTorch将ZeRO优化的数学定义映射到ShardingStrategy枚举。选择正确的策略在于内存限制和通信带宽之间的平衡。
FULL_SHARD (ZeRO 阶段 3): 这是默认行为。参数、梯度和优化器状态都被分片。参数只在计算特定层的正向和反向传播期间在GPU上具体化(聚集),然后立即释放。这提供了最大的内存节省(1/N 缩放),但由于频繁的AllGather操作而产生高通信开销。
SHARD_GRAD_OP (ZeRO 阶段 2): 梯度和优化器状态被分片,但参数保持复制状态。此策略避免了正向传播期间聚集权重的通信开销,但需要足够的显存来保存完整的模型参数(Ψ)。
NO_SHARD (DDP 等效): 此模式在FSDP API中复制DDP的行为。它有助于调试,或者在缩减到少量GPU时使用,因为那时分片开销超过了其优势。
以下图表说明了这些策略在双GPU设置中的内存分配区别。
标准复制和完全分片之间的内存分配比较。在
FULL_SHARD中,更新的存储和计算都已分布,设备占用量随大小线性减少。
不同于torch.cuda.amp通常需要外部缩放器和autocast上下文管理器,FSDP直接将混合精度整合到分片逻辑中。这种整合是必要的,原因在于FSDP在跨等级通信张量之前必须知道精度格式。如果通信采用FP32而计算采用BF16,则会浪费带宽。
MixedPrecision配置类控制三种特定数据类型:
param_dtype: 正向和反向计算期间用于模型参数的类型。reduce_dtype: 用于梯度归约(通信)的类型。buffer_dtype: 缓冲区(例如,批量归一化统计数据)的类型。对于Ampere或Hopper架构上的现代大型语言模型(LLM)训练,bfloat16是标准。它保留了FP32的指数范围,避免了FP16常见的下溢问题,通常无需损失缩放。
import torch
from torch.distributed.fsdp import MixedPrecision
# 定义BFloat16训练策略
bf16_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16, # 将通信量减少50%
buffer_dtype=torch.bfloat16,
)
当param_dtype设置为bfloat16时,FSDP会保留FP32主权重(如果存在),但在正向传播之前将其转换为BFloat16。这与混合精度训练中的“主权重”思想一致,确保优化器步骤中的小权重更新不会因精度截断而丢失。
当你用FSDP包装一个模块时,原始参数(例如model.layer1.weight)会被一个通用FlatParameter替换。这是一个单一的1D张量,将多个原始参数的存储视为一个整体。
这种扁平化提高了内存访问模式和通信效率。FSDP没有为单个权重矩阵启动数百个小型NCCL内核,而是将它们聚合为更大的块。
然而,这带来了一个限制:包装后你不能直接访问model.layer1.weight,因为该属性实际上不再以原始形式存在于设备上。访问它需要使用上下文管理器FSDP.summon_full_params(model),我们将在检查点部分进行说明。
以下示例展示了一种初始化模式。它建立进程组,定义模型,并使用FULL_SHARD和BFloat16精度将其包装。
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
StateDictType,
)
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# 使用NCCL后端初始化GPU的进程组
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def train_fsdp_step(rank, world_size):
setup(rank, world_size)
# 1. 定义模型 (标准PyTorch)
# 在实际场景中,这很可能是一个Transformer
model = nn.Sequential(
nn.Linear(1024, 4096),
nn.ReLU(),
nn.Linear(4096, 1024)
)
# 2. 定义策略
bf16_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# 3. 用FSDP包装
# 注意: 我们通过device_id参数隐式地移动到设备。
# 对于大型模型,首先加载到'meta'设备或CPU。
fsdp_model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=bf16_policy,
device_id=torch.cuda.current_device(),
)
# 4. 优化器初始化
# 重要: 优化器必须在包装后初始化
# 因为FSDP会修改参数结构。
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-3)
# 5. 训练循环
# 在正确的设备上生成模拟数据
inputs = torch.randn(64, 1024).to(rank).bfloat16()
optimizer.zero_grad()
output = fsdp_model(inputs)
loss = output.sum()
loss.backward()
optimizer.step()
print(f"Rank {rank} step complete. Loss: {loss.item()}")
cleanup()
# 这个函数将在脚本中通过torch.multiprocessing启动
# mp.spawn(train_fsdp_step, args=(world_size,), nprocs=world_size)
device_id在上面的示例中,device_id参数非常要紧。如果省略,FSDP可能会尝试在CPU或默认GPU上初始化分片,可能导致无声的性能下降或放置错误。
当模型能适应单个GPU(如示例所示)时,包装一个已具体化的模型是可以接受的。然而,对于接近TB级别的模型,在分片前将完整模型具体化到CPU内存是不可能的。在这些情况下,我们采用meta设备进行延迟初始化,允许FSDP分片参数,而无需在任何单个设备上分配完整的模型状态。这种高级初始化模式是下一章的内容。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造