趋近智
辅助负载均衡损失有助于防止专家崩溃,但出乎意料的是,其一个组成部分可能成为训练中出现显著不稳的原因。此组成部分常被称为路由器Z损失。了解其来源及如何处理,是成功训练大型MoE模型不可或缺的本领。
不稳源于门控网络产生的未经归一化的原始logits。回顾第一章可知,辅助损失包含一个旨在鼓励路由器使用多样专家集合的项。此项通常基于门控网络logits的平方和进行计算。
令 Laux 为辅助损失,且令 g(x)i 为给定输入token x 时,门控网络针对专家 i 产生的logit。Z损失分量 Lz 与这些logits的平方和成比例,并对批次中所有token进行平均。简化的表示为:
Lz=tokens∑(和(g(x)))2此损失的目的是让logits的数值保持较小,这间接促使专家上的softmax分布不那么尖锐,从而防止路由器过度自信,在训练早期将所有token路由到少数几个专家。
问题出现在当logits变得非常大时。因为此损失项是二次的,即使logit值适度增加,也可能导致 Lz 激增。如果发生这种情况,Z损失可能会压倒主要任务损失(例如交叉熵),反向传播巨大且无用的梯度通过门控网络。这可能导致整个训练过程不稳,使总损失飙升,模型表现下降。
控制Z损失最直接和常用的方式是将其与一个小系数进行缩放。此超参数通常称为 router_z_loss_coef 或类似名称,在将其加入总损失之前,乘以Z损失。
模型的总损失变为:
Ltotal=Ltask+α⋅Lbalance+β⋅Lz此处,β 是 router_z_loss_coef。通过将 β 设置为一个小值,通常在0.001到0.01之间,可以降低Z损失对总梯度的影响。
该系数的选择涉及一个权衡:
实际中,从 1e-3 这样的值开始是一种常用做法。监控训练日志中总损失的突然飙升,如果与辅助损失的飙升相对应,这是诊断该值是否需要调整的主要方式。下面的图表展示了路由器Z损失激增的典型不稳事件。
在第60步,路由器Z损失飙升,导致总损失相应跳跃。主要任务损失初期保持稳定,但如果在此不稳状态下继续训练,表现会下降。这是明确信号,表示应降低
router_z_loss_coef。
除了缩放损失,您还可以采用其他策略,通常是组合使用,以进一步提升稳定度。
门控网络的初始状态可以使模型易于出现不稳。如果门控网络中最终线性层的权重初始化过大,初始logits可能大到足以在训练的第一步就立即导致Z损失飙升。
一个简单且有效的技巧是,将此最终层的权重初始化为非常小的值,甚至为零。例如,使用一个标准差很小的截断正态分布(例如0.001)或对最终权重矩阵进行直接零初始化,可确保初始logits接近零。这使得专家选择接近均匀分布,让路由器能够逐渐学习其偏好,而不会导致初始损失激增。
另一种直接办法是,在logits用于计算Z损失之前,限制其数值。这是一个强力的预防措施,防止数值失控。您可以通过将logit张量钳制在预定义范围内来实现。
例如,在PyTorch中:
# 在您的MoE层中,计算logits之后
LOGIT_CAP = 30.0
# 仅为Z损失计算钳制logits
# 原始logits应用于softmax和路由
clamped_logits = torch.clamp(logits, -LOGIT_CAP, LOGIT_CAP)
# 现在使用clamped_logits计算Z损失
这确保了无论网络权重变得多大,任何单个logit对Z损失的贡献都受到限制。截断值的选择是另一个超参数,但它比损失系数不那么敏感。20到50之间的值通常足以防止最极端的数值问题。主要缺点是,如果限制过低,它可能使路由器的决策过程“饱和”,但其在这里的主要作用是作为稳定性的安全网。
通过结合合理的Z损失系数、细致的初始化以及可能的logit截断,您可以有效地控制路由器的行为,并创造必要的稳定条件,以训练即使是最大的专家混合模型。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造