趋近智
将大型模型扩展到数百个GPU时,计算能力与通信开销之间会出现非线性关系。全局同步,通常依赖于NCCL原语(例如AllGather)来具体化参数,会产生一定的开销,当跨越通过商用以太网或超额订阅的InfiniBand连接的多个服务器机架时,这种开销可能变得无法承受。
混合分片数据并行 (HSDP) 通过在FSDP的内存效率和分布式数据并行 (DDP) 的通信效率之间找到一个折中方案,从而应对这种吞吐量下降的问题。通过在高带宽区域(通常是带有NVLink的单个节点)内分片参数,并在较低带宽连接(节点间)上复制参数,HSDP优化了针对物理网络拓扑的训练循环。
在标准的FSDP配置中(常称为“完全分片”),模型状态被划分到整个WORLD_SIZE。如果您在128个GPU上训练一个100B参数的模型,每个GPU大约持有1/128的参数。为了执行前向传播,一个GPU必须获取剩余的127/128的数据,其中大部分通过节点间网络传输。
HSDP改变了这种分区策略。它引入了两个不同的进程组:
在前向和反向传播期间,AllGather操作只在分片组内进行。这使得大量带宽使用局限在本地NVLink网络中。节点间的通信被限制为梯度同步(类似于DDP),这在反向传播期间发生。
下图展示了具有两个节点,每个节点包含四个GPU的HSDP设置中的数据流和状态分布。
HSDP将主要的参数收集操作隔离到高速节点内链接,而较慢的节点间链接仅用于梯度归约。
为了量化HSDP的优势,我们分析了每个训练步骤所需的通信量。设Ψ为模型参数量,N为GPU总数,S为分片组的大小(通常是每个节点的GPU数量)。
在标准FSDP中,前向传播期间每个GPU的AllGather通信量VFSDP为:
VFSDP=NN−1⋅Ψ⋅每个参数的字节数
当N→∞时,每个GPU必须下载几乎整个模型Ψ。
在HSDP中,AllGather仅限于分片组S。通信量VHSDP变为:
VHSDP=SS−1⋅Ψ⋅每个参数的字节数
如果我们每个节点有8个GPU(S=8),并且集群有16个节点(N=128),标准FSDP需要通过聚合网络检索128127Ψ。HSDP仅需通过本地NVLink检索87Ψ。尽管通信量87Ψ略低于128127Ψ,但差异在于该传输可用的带宽(B)。本地互连通常提供600-900 GB/s的带宽,而节点间以太网可能仅提供25-50 GB/s。
延迟减少量Lgain可以通过比较传输时间来近似表示:
TFSDP≈B节点间Ψ对比THSDP≈B节点内Ψ+T梯度同步
HSDP的限制在于内存容量。通过跨节点复制模型,您将失去完全FSDP的全局内存聚合能力。
如果您的模型适合单个节点的总VRAM(例如,8x80GB = 640GB),HSDP允许您扩展到数千个GPU,而不会让互连成为瓶颈。如果模型超过单个节点的容量,则必须使用完全FSDP或模型并行,而不论网络开销如何。
在现代PyTorch(2.x+版本)中实现HSDP,依赖于DeviceMesh抽象。设备网格允许您将GPU集群视为一个多维网格。对于HSDP,我们构建一个二维网格:一个轴用于复制(节点间),一个轴用于分片(节点内)。
以下代码演示了为具有4个节点、每个节点8个GPU的集群初始化设备网格。
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
def setup_hsdp_mesh():
# 假设已完成标准分布式初始化
# world_size = 32, rank = 0..31
# 定义拓扑:
# replicate_on: 节点组(节点间)
# shard_on: 节点内的GPU(节点内)
# 形状 (4, 8) 表示 4 个副本,每个副本在 8 个设备上分片
mesh_2d = init_device_mesh(
"cuda",
(4, 8),
mesh_dim_names=("replicate", "shard")
)
return mesh_2d
# 在FSDP中的使用
# 封装模型时,传入device_mesh
# fsdp_model = FSDP(model, device_mesh=mesh_2d, ...)
当device_mesh被传入FSDP构造函数时,PyTorch会自动检测二维结构。它沿“分片”维度应用ShardingStrategy.SHARD_GRAD_OP (ZeRO-2) 或 FULL_SHARD (ZeRO-3),并沿“复制”维度应用NO_SHARD (复制)。
通过sharding_strategy参数可获得两种主要的混合分片方式,不过对于精细调优,显式设备网格控制更受推荐。
HYBRID_SHARD:这等同于节点内的ZeRO-3。参数、梯度和优化器状态在节点内完全分片。跨节点时,这些分片会被复制。_HYBRID_SHARD_ZERO2:这在节点内应用ZeRO-2。参数不进行分片(在前向传播后完整保留或一次性收集),而梯度和优化器状态则进行分片。这会消耗更多的内存,但进一步减少了通信。选择正确的策略依据模型的算术强度和特定瓶颈(计算密集型、内存密集型或IO密集型)。
下图展示了吞吐量扩展效率。请留意,当节点数量增加时,标准FSDP在低带宽以太网集群上的扩展性能如何下降,而HSDP通过将大量流量本地化来保持接近线性的扩展。
当网络带宽饱和时,标准FSDP会出现性能下降,而HSDP通过发挥本地带宽的优势来保持吞吐量。
实现HSDP需要仔细关注集群的同构性。由于HSDP依赖于复制“分片组”的状态,因此每个分片组必须相同。您无法在同一个HSDP网格中轻易混用具有8个GPU的节点和具有4个GPU的节点,因为分片分区在数学上将无法对齐。
此外,在保存检查点时,HSDP提供了一个独特的优势。由于模型在单个节点内完整存在(以分片形式),您可以配置检查点逻辑,使其仅从一个复制组(例如,复制维度的rank 0)保存状态。这与集群中每个GPU同时写入数据相比,减少了存储系统的I/O压力。
如果您不使用DeviceMesh API,使用ShardingStrategy.HYBRID_SHARD需要显式处理进程组。您需要手动为节点内通信创建一个进程组,并将其传入FSDP构造函数:
# 传统方法(DeviceMesh之前)
import torch.distributed as dist
from torch.distributed.fsdp import ShardingStrategy
# 创建节点内进程组
node_pg = dist.new_group(ranks_in_this_node)
inter_node_pg = dist.new_group(ranks_across_nodes)
model = FSDP(
model,
process_group=(node_pg, inter_node_pg),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
# ...
)
DeviceMesh方法强烈推荐用于新的实现,因为它抽象了rank计算的复杂性,并确保与其他分布式功能(如张量并行)的兼容性。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造