训练专家混合模型(Mixture of Experts)需特别注意计算如何在专家间分配。若不加限制,路由机制易产生偏差,导致部分专家长期未充分利用(“饥饿”),而另一些则负载过重。这种不平衡抵消了条件计算带来的效率提升,并可能严重影响模型的有效学习能力。辅助损失函数提供了一种直接方法来对抗这种趋势,通过在整体训练目标中增加惩罚项,明确促使专家得到更均匀的使用。
组合损失函数通常表现为:
L总=L任务+αL辅助
在此,L任务 是主要目标的标准损失函数(例如,分类的交叉熵损失,语言建模损失),L辅助 是为促进平衡而设计的辅助损失项,而 α 是一个标量超参数,它控制平衡目标与任务目标之间的相对重要性。
辅助损失的必要性
若无辅助损失,路由器的学习仅由最小化 L任务 驱动。这会形成正反馈循环,导致那些最初在某些输入上表现稍好的专家获得更多令牌,使它们能进一步专业化并吸引更多令牌,最终导致严重不平衡。导致此情况的因素包括:
- 初始化: 随机初始化最初可能偏向某些专家。
- 数据分布: 训练数据的不均匀性可能自然地与特定专家更匹配。
- 容量限制: 若强制执行专家容量(限制每个专家的令牌数),持续过载的专家可能丢弃令牌,间接指示路由器避免选择它们,从而可能使其他有能力的专家“饥饿”。
辅助损失通过引入明确的优化压力来平衡负载,从而打破这些循环。
辅助损失的常见形式
已提出几种 L辅助 的形式,主要侧重于分配给专家的令牌分布,或门控网络产生的概率分布。
1. 负载均衡损失(基于令牌分布)
这可能是最常用的方法,直接惩罚批次中每个专家处理的令牌数量不平衡。设 N 为专家数量。对于给定批次的 T 个令牌:
- 让 h(xt) 表示令牌 t 的门控网络输出(路由器概率)。h(xt)i 是分配给专家 i 的概率。
- 将 fi 定义为批次中路由到专家 i 的令牌比例。若使用 top-1 路由,这简单地是专家 i 具有最高概率的令牌比例。若使用 top-k 路由,这是专家 i 被选为 top-k 专家之一的令牌比例。
- 将 Pi 定义为路由器在批次中所有令牌上分配给专家 i 的平均概率:Pi=T1∑t=1Th(xt)i。
一种广泛使用的负载均衡损失,源自 Switch Transformer 等工作,其形式如下:
L负载=N⋅∑i=1NfiPi
思想: 这种损失促使路由器更均匀地分配令牌。它计算分配给每个专家的令牌比例 (fi) 与该专家的平均路由器概率 (Pi) 之间的点积。最小化此项可避免专家获得大量令牌(fi 高)同时也被分配高平均概率(Pi 高)的情况。实际上,它惩罚路由器将分配频率和概率质量集中到少数专家身上。缩放因子 N 使损失幅度与专家数量保持一致。此损失需要根据当前路由决策按批次计算。
2. 路由器概率方差损失(CV 平方)
另一种方法侧重于令牌分配前路由器概率的方差。目标是促使门控网络在批次中平均为所有专家输出相似的概率。
- 计算 Pi,即如上所述分配给专家 i 的平均概率。
- 计算这些平均概率的均值:Pˉ=N1∑i=1NPi。注意 Pˉ=1/N,因为每个令牌的概率 h(xt) 之和为 1。
- 计算这些平均概率的方差:Var(P)=N1∑i=1N(Pi−Pˉ)2=N1∑i=1N(Pi−1/N)2。
变异系数平方 (CV2) 损失为:
Lcv=Pˉ2Var(P)=N2⋅Var(P)=N∑i=1N(Pi−1/N)2
思想: 这种损失直接衡量路由器分配的平均概率的不平衡程度。最小化 Lcv 会促使每个专家的平均概率 Pi 趋近于理想的均匀值 1/N,从而鼓励路由器平均更平等地对待所有专家。它侧重于路由器的输出分布,而非最终的令牌计数。
8 个专家每批次令牌分配的示例分布,比较了无辅助损失(不平衡)与应用 L负载 或 Lcv 后(使用更均匀)的情况。
3. 路由器 Logit 正则化(例如,Z-损失)
某些方法直接对门控网络在 Softmax 激活前产生的 logit 进行正则化。“路由器 Z-损失”就是一个例子,其目的是控制这些 logit 的幅度。一种简化形式可能会惩罚每个令牌 logit 的平方和:
Lz∝∑t=1T∑i=1N(logitt,i)2
思想: 大的 logit 值在 Softmax 后会导致尖锐、高置信度的概率分布。惩罚大的 logit 鼓励路由器产生更柔和的概率,尤其是在训练早期。这可以防止路由器过早地将所有令牌分配给一个或少数几个专家,从而提高训练稳定性,并可能在专业化发生前有助于泛化学习。
平衡系数 α 的调整
超参数 α 很重要。它调节着优化主要任务与强制专家负载均衡之间的权衡。
- 小 α: 平衡损失影响不大,专家负载不平衡可能持续存在。
- 大 α: 平衡目标可能主导任务损失,可能减缓甚至阻碍主要任务的收敛。模型可能优先考虑平衡而非学习有用的表示。
确定 α 的合适值通常是经验性的。常见做法包括:
- 典型范围: 值通常较小,例如 10−2 或 10−3,但最优值很大程度上取决于具体的模型架构、任务、专家数量以及所选的辅助损失形式。
- 监控: 在训练期间观察专家使用指标。跟踪每个专家在每个训练步骤或周期处理的令牌数量。TensorBoard 或 Weights & Biases 等工具在此非常有帮助。同时监控 L辅助 相对于 L任务 的幅度。
- 迭代调整: 从小 α 开始。若观察到明显不平衡(例如,某些专家获得的令牌持续远低于平均水平),则逐渐增加 α。确保增加 α 不会导致主要任务性能(验证损失/准确率)显著下降。
- 调度: 一些实践者在训练期间对 α 进行调度,例如,可能从稍高的值开始以在早期强制平衡,然后随着训练进行和专家自然专业化而逐渐降低。
与 Top-k 路由的配合
使用 top-k 路由(其中 k>1)时,辅助损失通常基于门控概率计算,在进行 top-k 选择之前。例如,L负载 仍会使用 Pi(分配给所有令牌中专家 i 的平均概率)和 fi(专家 i 被选为 top-k 之一的令牌比例)。即使每个令牌激活多个专家,损失仍致力于平衡底层概率分布。
实际考量
- 实现: 辅助损失在前向传播中计算,通常在门控网络计算完概率之后,但在专家计算之前(尽管 L负载 中的 fi 可能需要知道最终分配结果)。然后,在反向传播之前,将其添加到任务损失中。
- 梯度流: 确保来自 L辅助 的梯度流回到门控网络的参数。
- 影响: 引入 L辅助 可能会稍微改变初始训练动态。尽管它最初可能看起来会减缓 L任务 的收敛,但改进的稳定性和专家使用效率通常会带来更好的最终模型性能和更有效的参数利用。
通过仔细选择和调整辅助损失函数,可以缓解 MoE 训练中固有的负载均衡难题,为稳定学习和高效专家专业化提供途径。后续章节将分析路由器的其他优化策略,并讨论令牌丢弃和专业化崩溃等问题。