虽然像 DPM-Solver 这样的优化的采样器可以减少生成所需的步数,但基础扩散模型通常仍然庞大且计算密集。模型蒸馏提供了一种补充方法,通过创建更小、更快的“学生”模型来模拟较大、预训练 (pre-training)的“教师”扩散模型的行为,从而加速推理 (inference)。这项技术沿用了更广义的深度学习 (deep learning)实践,即将知识从复杂模型迁移到更高效的模型中。
蒸馏技术在扩散模型中的主要目标是训练一个学生网络 θstudent,使其模仿一个固定的、预训练的教师网络 θteacher 所做的预测。学生模型通常具有明显更小的架构(例如,更少的层、通道或注意力头),使其在推理时更快、内存占用更少。
蒸馏目标
有几种策略可以定义学生模型应如何模仿教师模型:
-
匹配去噪预测: 最直接的方法是训练学生模型,使其在给定带噪声输入 xt 和时间步 t 时,预测与教师模型相同的输出。如果两个模型都预测噪声 ϵ,则蒸馏损失旨在最小化其预测之间的差异:
Ldistill=Ex,ϵ,t[w(t)∣∣ϵθstudent(xt,t)−ϵθteacher(xt,t)∣∣2]
这里,x 是数据样本,ϵ 是采样噪声,t 是时间步,xt=αˉtx+1−αˉtϵ 是带噪声输入,w(t) 是一个可选的加权项,可以优先在特定时间步进行匹配。如果模型预测去噪样本 x0,也可以定义类似的目标。在此过程中,教师模型的参数 (parameter) (θteacher) 保持不变。
-
匹配概率流 ODE 解: 学生模型可以被训练来近似由与教师模型相关联的概率流 ODE 定义的轨迹,而不仅仅是匹配离散步长的输出。这通常包括匹配预测的分数场或速度场。
-
特征层蒸馏: 除了仅仅匹配最终输出外,还可以鼓励学生模型复制教师模型在特定层的内部特征表示。这需要添加辅助损失项,以最小化中间激活之间的差异,这可能提供更丰富的训练信号。
渐进式蒸馏
一个主要挑战是,执行单步生成的学生模型可能难以复制多步教师模型的质量。Salimans & Ho (2022) 的论文“用于扩散模型快速采样的渐进式蒸馏”提出了一种有效的技术来解决此问题。
主要思路是迭代地蒸馏采样过程。
- 阶段 1: 训练一个学生模型 (S1),使其用 一步 完成教师模型 (T) 用 两步 完成的工作。具体来说,S1(xt,t) 被训练来预测 T 从 xt 开始,应用其更新规则两次后会产生的输出。这有效地将所需的采样步数减半 (N→N/2)。
- 阶段 2: 将已训练的学生 S1 视为新的教师模型。训练第二个学生模型 (S2),使其用 一步 完成 S1 用 两步 完成的工作。这进一步将采样步数减半 (N/2→N/4)。
- 重复: 迭代地继续这个过程。每个阶段都将所需函数评估 (NFE) 的数量减半。在 k 个阶段后,最终的学生模型只需要 N/2k 步。
这种渐进式方法使学生模型能够跨多个中间阶段学习复杂的映射,与尝试大幅减少步数的单阶段蒸馏相比,通常能获得更好的样本质量。
示意图展示了渐进式蒸馏过程。每个学生模型学习模拟前一个模型(教师或前一个学生模型)的两步,从而迭代地减少所需的采样步数。
与一致性模型的关系
模型蒸馏,特别是渐进式蒸馏,与一致性模型(在第 5 章中讨论过)具有相同的加速采样的目标。然而,其机制有所不同:
- 一致性模型: 学习一个函数,通过沿着 ODE/SDE 轨迹施加一致性属性,将任何点 (xt,t) 直接映射到轨迹起点 x0。它们可以通过类似蒸馏的过程或独立地进行训练。
- 渐进式蒸馏: 侧重于通过训练学生模型来预测两个教师步长的结果,从而将离散采样步数减半。它不明确地在整个连续轨迹上施加相同的一致性属性。
在实践中,这两种方法都能带来明显的加速,通常能够在 1 到 8 步内完成生成。渐进式蒸馏在保留原始多步采样过程的某些方面可能提供更大的灵活性,而一致性模型则是专门为基于一致性属性的极少步或单步生成而设计。
学生模型的架构选择
蒸馏技术的一个主要优点是选择学生模型架构的灵活性。它无需与教师模型匹配。常见选择包括:
- 使用相同的架构类型(例如 U-Net),但具有更小的深度、宽度(通道数)或更少的注意力头。
- 采用完全不同、更高效的架构。
选择很大程度上取决于目标应用、所需的推理 (inference)速度、可接受的质量权衡以及可用于推理的计算资源。
优点与缺点
优点:
- 速度与效率: 相比于原始教师模型,大幅减少采样时间和计算成本。
- 部署: 使扩散模型能够部署在计算能力有限或对延迟有严格要求的设备上。
- 灵活性: 学生架构可以根据具体需求进行定制。
缺点:
- 质量下降: 蒸馏模型通常在样本质量或多样性方面相比于教师模型有所降低。这种权衡通常在压缩程度更高(步数更少)时更为明显。
- 训练成本: 蒸馏过程本身可能计算量大,需要使用大型教师模型的输出训练一个或多个学生模型。
- 超参数 (parameter) (hyperparameter)调整: 寻找最佳蒸馏目标、学生架构和训练计划需要仔细的实验。
“模型蒸馏为优化扩散模型提供了一系列有价值的技术,并在采样方面进行了算法改进。通过创建更小、更快的学生模型,蒸馏使当前前沿的生成能力在应用中更具实用性,补充了本章稍后讨论的量化 (quantization)和硬件加速等方法。”