将大型模型扩展到数百个GPU时,计算能力与通信开销之间会出现非线性关系。全局同步,通常依赖于NCCL原语(例如AllGather)来具体化参数,会产生一定的开销,当跨越通过商用以太网或超额订阅的InfiniBand连接的多个服务器机架时,这种开销可能变得无法承受。混合分片数据并行 (HSDP) 通过在FSDP的内存效率和分布式数据并行 (DDP) 的通信效率之间找到一个折中方案,从而应对这种吞吐量下降的问题。通过在高带宽区域(通常是带有NVLink的单个节点)内分片参数,并在较低带宽连接(节点间)上复制参数,HSDP优化了针对物理网络拓扑的训练循环。混合分片的架构在标准的FSDP配置中(常称为“完全分片”),模型状态被划分到整个WORLD_SIZE。如果您在128个GPU上训练一个100B参数的模型,每个GPU大约持有1/128的参数。为了执行前向传播,一个GPU必须获取剩余的127/128的数据,其中大部分通过节点间网络传输。HSDP改变了这种分区策略。它引入了两个不同的进程组:分片组(节点内): 参数在通过高速互连(例如NVLink)连接的GPU之间进行分片。复制组(节点间): 一个分片组的集合状态会在不同节点间复制。在前向和反向传播期间,AllGather操作只在分片组内进行。这使得大量带宽使用局限在本地NVLink网络中。节点间的通信被限制为梯度同步(类似于DDP),这在反向传播期间发生。下图展示了具有两个节点,每个节点包含四个GPU的HSDP设置中的数据流和状态分布。digraph G { rankdir=TB; compound=true; node [shape=rect, style=filled, fontname="Arial", fontsize=10]; edge [fontname="Arial", fontsize=9]; subgraph cluster_node1 { label="节点 1 (副本 A)"; style=filled; color="#f1f3f5"; node [fillcolor="#a5d8ff", color="#1c7ed6"]; GPU0 [label="GPU 0\n分片 1/4"]; GPU1 [label="GPU 1\n分片 2/4"]; GPU2 [label="GPU 2\n分片 3/4"]; GPU3 [label="GPU 3\n分片 4/4"]; {rank=same; GPU0; GPU1; GPU2; GPU3} GPU0 -> GPU1 [dir=both, color="#1c7ed6", label="高带宽 AllGather"]; GPU1 -> GPU2 [dir=both, color="#1c7ed6"]; GPU2 -> GPU3 [dir=both, color="#1c7ed6"]; } subgraph cluster_node2 { label="节点 2 (副本 B)"; style=filled; color="#f1f3f5"; node [fillcolor="#ffc9c9", color="#fa5252"]; GPU4 [label="GPU 4\n分片 1/4"]; GPU5 [label="GPU 5\n分片 2/4"]; GPU6 [label="GPU 6\n分片 3/4"]; GPU7 [label="GPU 7\n分片 4/4"]; {rank=same; GPU4; GPU5; GPU6; GPU7} GPU4 -> GPU5 [dir=both, color="#fa5252"]; GPU5 -> GPU6 [dir=both, color="#fa5252"]; GPU6 -> GPU7 [dir=both, color="#fa5252"]; } GPU0 -> GPU4 [dir=both, style=dashed, color="#868e96", label="低带宽梯度同步"]; GPU3 -> GPU7 [dir=both, style=dashed, color="#868e96"]; }HSDP将主要的参数收集操作隔离到高速节点内链接,而较慢的节点间链接仅用于梯度归约。通信量分析为了量化HSDP的优势,我们分析了每个训练步骤所需的通信量。设$\Psi$为模型参数量,$N$为GPU总数,$S$为分片组的大小(通常是每个节点的GPU数量)。在标准FSDP中,前向传播期间每个GPU的AllGather通信量$V_{FSDP}$为:$$ V_{FSDP} = \frac{N-1}{N} \cdot \Psi \cdot \text{每个参数的字节数} $$当$N \to \infty$时,每个GPU必须下载几乎整个模型$\Psi$。在HSDP中,AllGather仅限于分片组$S$。通信量$V_{HSDP}$变为:$$ V_{HSDP} = \frac{S-1}{S} \cdot \Psi \cdot \text{每个参数的字节数} $$如果我们每个节点有8个GPU($S=8$),并且集群有16个节点($N=128$),标准FSDP需要通过聚合网络检索$\frac{127}{128}\Psi$。HSDP仅需通过本地NVLink检索$\frac{7}{8}\Psi$。尽管通信量$\frac{7}{8}\Psi$略低于$\frac{127}{128}\Psi$,但差异在于该传输可用的带宽($B$)。本地互连通常提供600-900 GB/s的带宽,而节点间以太网可能仅提供25-50 GB/s。延迟减少量$L_{gain}$可以通过比较传输时间来近似表示:$$ T_{FSDP} \approx \frac{\Psi}{B_{\text{节点间}}} \quad \text{对比} \quad T_{HSDP} \approx \frac{\Psi}{B_{\text{节点内}}} + T_{\text{梯度同步}} $$内存权衡HSDP的限制在于内存容量。通过跨节点复制模型,您将失去完全FSDP的全局内存聚合能力。完全FSDP容量: 总内存 = $N \times \text{VRAM}_{GPU}$HSDP容量: 总内存 = $S \times \text{VRAM}_{GPU}$如果您的模型适合单个节点的总VRAM(例如,8x80GB = 640GB),HSDP允许您扩展到数千个GPU,而不会让互连成为瓶颈。如果模型超过单个节点的容量,则必须使用完全FSDP或模型并行,而不论网络开销如何。配置设备网格在现代PyTorch(2.x+版本)中实现HSDP,依赖于DeviceMesh抽象。设备网格允许您将GPU集群视为一个多维网格。对于HSDP,我们构建一个二维网格:一个轴用于复制(节点间),一个轴用于分片(节点内)。以下代码演示了为具有4个节点、每个节点8个GPU的集群初始化设备网格。import torch import torch.distributed as dist from torch.distributed.device_mesh import init_device_mesh def setup_hsdp_mesh(): # 假设已完成标准分布式初始化 # world_size = 32, rank = 0..31 # 定义拓扑: # replicate_on: 节点组(节点间) # shard_on: 节点内的GPU(节点内) # 形状 (4, 8) 表示 4 个副本,每个副本在 8 个设备上分片 mesh_2d = init_device_mesh( "cuda", (4, 8), mesh_dim_names=("replicate", "shard") ) return mesh_2d # 在FSDP中的使用 # 封装模型时,传入device_mesh # fsdp_model = FSDP(model, device_mesh=mesh_2d, ...)当device_mesh被传入FSDP构造函数时,PyTorch会自动检测二维结构。它沿“分片”维度应用ShardingStrategy.SHARD_GRAD_OP (ZeRO-2) 或 FULL_SHARD (ZeRO-3),并沿“复制”维度应用NO_SHARD (复制)。实现策略通过sharding_strategy参数可获得两种主要的混合分片方式,不过对于精细调优,显式设备网格控制更受推荐。HYBRID_SHARD:这等同于节点内的ZeRO-3。参数、梯度和优化器状态在节点内完全分片。跨节点时,这些分片会被复制。_HYBRID_SHARD_ZERO2:这在节点内应用ZeRO-2。参数不进行分片(在前向传播后完整保留或一次性收集),而梯度和优化器状态则进行分片。这会消耗更多的内存,但进一步减少了通信。选择正确的策略依据模型的算术强度和特定瓶颈(计算密集型、内存密集型或IO密集型)。下图展示了吞吐量扩展效率。请留意,当节点数量增加时,标准FSDP在低带宽以太网集群上的扩展性能如何下降,而HSDP通过将大量流量本地化来保持接近线性的扩展。{ "layout": { "title": "扩展效率:100Gbps 以太网上的 FSDP 对比 HSDP", "xaxis": { "title": "节点数量 (每个节点 8 个 GPU)", "tickvals": [1, 2, 4, 8, 16], "showgrid": true, "gridcolor": "#e9ecef" }, "yaxis": { "title": "每个 GPU 的吞吐量 (TFLOPS)", "range": [0, 200], "showgrid": true, "gridcolor": "#e9ecef" }, "plot_bgcolor": "white", "width": 600, "height": 400, "legend": {"x": 0.7, "y": 1} }, "data": [ { "x": [1, 2, 4, 8, 16], "y": [180, 175, 160, 130, 95], "type": "scatter", "mode": "lines+markers", "name": "标准 FSDP", "line": {"color": "#fa5252", "width": 3} }, { "x": [1, 2, 4, 8, 16], "y": [180, 178, 176, 174, 172], "type": "scatter", "mode": "lines+markers", "name": "混合分片 (HSDP)", "line": {"color": "#228be6", "width": 3} } ] }当网络带宽饱和时,标准FSDP会出现性能下降,而HSDP通过发挥本地带宽的优势来保持吞吐量。生产环境的实用说明实现HSDP需要仔细关注集群的同构性。由于HSDP依赖于复制“分片组”的状态,因此每个分片组必须相同。您无法在同一个HSDP网格中轻易混用具有8个GPU的节点和具有4个GPU的节点,因为分片分区在数学上将无法对齐。此外,在保存检查点时,HSDP提供了一个独特的优势。由于模型在单个节点内完整存在(以分片形式),您可以配置检查点逻辑,使其仅从一个复制组(例如,复制维度的rank 0)保存状态。这与集群中每个GPU同时写入数据相比,减少了存储系统的I/O压力。如果您不使用DeviceMesh API,使用ShardingStrategy.HYBRID_SHARD需要显式处理进程组。您需要手动为节点内通信创建一个进程组,并将其传入FSDP构造函数:# 传统方法(DeviceMesh之前) import torch.distributed as dist from torch.distributed.fsdp import ShardingStrategy # 创建节点内进程组 node_pg = dist.new_group(ranks_in_this_node) inter_node_pg = dist.new_group(ranks_across_nodes) model = FSDP( model, process_group=(node_pg, inter_node_pg), sharding_strategy=ShardingStrategy.HYBRID_SHARD, # ... )DeviceMesh方法强烈推荐用于新的实现,因为它抽象了rank计算的复杂性,并确保与其他分布式功能(如张量并行)的兼容性。