趋近智
虽然路由机制(例如 top-k 门控)会将输入令牌导向专家混合 (MoE) 层中的特定专家,但它们本身不保证计算负载在所有专家之间均匀分配。通常,如果没有明确干预,门控网络可能会偏好一小部分专家,从而导致严重的负载不均。这种不均带来了以下几个问题:
因此,在训练 MoE 模型时,实施促进负载均衡的机制是一种标准做法。最常见的方法是向主要任务损失(例如,语言建模的交叉熵损失)添加一个辅助损失项。此辅助损失会惩罚不平衡,指导门控网络更均匀地分配令牌。
辅助损失的目标是激励路由器为每个专家分配大致相同数量的令牌。Switch Transformer 论文和后续工作中介绍的一种被广泛采用的公式,旨在最小化每个专家处理的令牌数量的变异。
令 为专家数量, 为当前批次(或微批次)中的令牌数量。对于每个专家 ,我们可以定义两个量:
辅助负载均衡损失 通常计算为这两个向量 (vector)的点积,并乘以专家数量 和一个可调超参数 (parameter) (hyperparameter) :
用于反向传播 (backpropagation)的总损失是主要任务损失 和均衡损失之和:
直观理解: 最小化 促使所有专家的 和 都接近 。如果某个专家接收了大量令牌份额( 较高),损失就会增加。类似地,如果门控网络分配给某个专家高概率( 较高),损失也会增加。当实际分配 () 和路由器的置信度 () 均匀分布时,损失最小化。超参数 控制这种均衡激励相对于主要任务目标的强度;典型值通常很小(例如 0.01)。
下面是一个 PyTorch 代码片段,说明了计算方法,假设 gating_outputs 包含来自门控网络的概率,而 indices 包含每个令牌选择的专家索引:
import torch
import torch.nn.functional as F
# 示例输入(替换为实际模型输出)
# gating_outputs: 形状 [num_tokens, num_experts] - softmax 概率
# indices: 形状 [num_tokens, k] - 每个令牌选择的 top-k 专家索引
num_experts = 8
num_tokens = 1024
k = 2
gating_outputs = torch.randn(num_tokens, num_experts).softmax(dim=-1)
# 模拟 top-k 路由索引(实际中这些来自路由器)
indices = torch.topk(gating_outputs, k, dim=-1).indices
# --- 辅助损失计算 ---
# 计算 f_i: 路由到专家 i 的令牌分数
expert_mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)
# 形状 [num_tokens, num_experts],如果选择了专家则为 1,否则为 0
tokens_per_expert = expert_mask.sum(dim=0) # 形状 [num_experts]
f_i = tokens_per_expert / num_tokens # 每个专家的令牌分数
# 计算 P_i: 专家 i 的平均路由概率
P_i = gating_outputs.mean(dim=0) # 形状 [num_experts]
# 计算损失
# 注意: N = num_experts
load_balance_loss = num_experts * torch.sum(f_i * P_i)
# 示例: 添加到主要任务损失(假设 alpha = 0.01)
alpha = 0.01
# task_loss = ... (在其他地方计算)
# total_loss = task_loss + alpha * load_balance_loss
# total_loss.backward()
print(f"负载均衡损失项: {load_balance_loss.item():.4f}")
print(f"每个专家的令牌分布: {f_i.detach().numpy()}")
print(f"每个专家的平均概率: {P_i.detach().numpy()}")
示例可视化,比较专家间不均衡的令牌分布与理想的完全均衡状态。辅助损失旨在将分布推向均衡状态。
另一个常与辅助损失结合使用的机制是容量因子 ()。这限制了任何单个专家在一个批次内可以处理的令牌数量。每个专家的容量通常设置为:
容量因子 通常略大于 1(例如 1.25 或 1.5)。如果路由机制分配给某个专家的令牌数量超过其允许的容量,则超额令牌被视为“丢弃”或“溢出”。这些丢弃的令牌不参与该 MoE 层的计算(包括前向和反向传播 (backpropagation)),实际上像通过了一个恒等函数一样被处理。
虽然丢弃令牌可能看起来有害,但使用容量因子提供了对抗严重不平衡的硬性约束。它防止单个专家过载,即使辅助损失尚未完全纠正路由器的偏好。然而,将 设置得过低可能导致过多的令牌丢弃,从而阻碍学习。在强制均衡和保留所有信息之间存在权衡。在训练期间监控丢弃令牌的百分比对于调整 很重要。
负载均衡在分布式设置中尤为重要,尤其是在使用专家并行时,即不同专家位于不同的计算设备(例如 GPU)上。如果负载不均衡,承载受偏好专家的设备会成为瓶颈,而拥有未充分利用专家的设备则处于空闲状态,导致扩展性差和资源浪费。辅助损失和容量因子共同作用,以确保计算在分布式硬件上更均匀地分布。
总之,实现专家间负载的均衡分配对于高效且稳定的 MoE 模型训练来说,是必要条件。辅助负载均衡损失与仔细调整的容量因子相结合,提供了有效机制来鼓励门控网络更均匀地使用所有专家,从而最大化条件计算的益处。调整辅助损失系数 和容量因子 是成功训练大型 MoE 模型的重要方面。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造