虽然精巧的架构奠定了基础,但训练过程本身提供了强大的手段来提升扩散模型的能力。一种较早但很重要的技术,用于将生成过程引向特定的期望属性(例如特定的图像类别),是分类器引导。该方法通过引入外部分类器模型的信息来调整采样轨迹。
想象您有一个在 ImageNet 等数据集上进行无条件训练的扩散模型。在采样过程中,它会生成代表整体数据集的图像。但如果您特别想要一张“金毛犬”的图像怎么办?分类器引导提供了一种在推断时注入这种条件信息的机制,促使采样过程生成分类器识别为目标类别的图像。
核心原理:梯度引导
分类器引导完全在采样阶段运行。它不改变核心扩散模型 ϵθ(xt,t) 的训练,该模型通常训练用于无条件或简单条件地预测时间步 t 时添加的噪声。相反,它使用一个独立的、预训练的分类器 pϕ(y∣xt),该分类器经过专门训练,即使图像 x 带有噪声(在时间步 t),也能预测其类别 y。
在逆扩散过程(采样)的每一步中,我们从带噪声的图像 xt 开始。我们希望计算下一个噪声略少的图像 xt−1。核心思想是根据扩散模型的预测,并结合如何使 xt 根据分类器 pϕ 更有可能被分类为目标类别 y 来调整这一步的方向。
这种“更有可能”在数学上转化为使用分类器对输入图像 xt 的对数概率的梯度。即,我们计算 ∇xtlogpϕ(y∣xt)。这个梯度向量指向输入空间 (xt) 中最能增加分类器对目标类别 y 信心的方向。
数学公式
回想一下,在标准 DDPM 或 DDIM 采样步骤中,扩散模型 ϵθ(xt,t) 预测了可能添加到 xt 以得到它的噪声 ϵ。分类器引导修改了这一预测。
其基本原理与贝叶斯定理和得分匹配有关。得分函数 ∇xtlogp(xt∣y) 代表了在给定条件 y 下,为增加 xt 的似然性而前进的方向。这可以近似为:
∇xtlogp(xt∣y)≈∇xtlogp(xt)+∇xtlogp(y∣xt)
在此,∇xtlogp(xt) 是无条件分布的得分(与扩散模型的输出有关),而 ∇xtlogp(y∣xt) 是由外部分类器 pϕ 提供的得分。
就噪声预测 ϵ 而言,引导调整了扩散模型的输出 ϵθ(xt,t)。用于采样步骤的修改后的噪声预测 ϵ^θ(xt,t,y) 计算如下:
ϵ^θ(xt,t,y)=ϵθ(xt,t)−s⋅σt⋅∇xtlogpϕ(y∣xt)
让我们分解这些组成部分:
- ϵθ(xt,t): 无条件训练的(或基础条件的)扩散模型对当前带噪声图像 xt 和时间步 t 的原始噪声预测。
- pϕ(y∣xt): 预训练分类器对带噪声图像 xt 属于目标类别 y 的概率估计。
- ∇xtlogpϕ(y∣xt): 目标类别 y 的对数概率相对于输入带噪声图像 xt 的梯度。这是来自分类器的“引导信号”。它需要通过分类器网络计算梯度。
- s: 引导尺度(或强度)。这是一个超参数(s≥0),控制分类器梯度对噪声预测的影响强度。当 s=0 时,恢复原始无条件采样。较大的 s 值会更强烈地促使生成过程趋向目标类别 y。
- σt: 与时间步 t 噪声水平相关的缩放因子。通常,这与噪声的标准差有关,例如,在 DDPM 表示法中,σt=1−αˉt,尽管存在变体。该项有助于平衡梯度相对于不同时间步噪声预测的幅度。
生成的 ϵ^θ 随后被用于标准的 DDPM 或 DDIM 更新公式来计算 xt−1。
实现步骤
实现分类器引导涉及以下主要组成部分:
- 预训练扩散模型: 您需要一个标准扩散模型 ϵθ(xt,t),该模型可以无条件训练,或者可能带有与引导目标无关的基础条件。
- 预训练噪声分类器: 这是非常重要的一部分。您需要一个独立的分类器网络 pϕ(y∣x),该网络经过专门训练,能够对被与不同扩散时间步 t 对应的噪声损坏的图像进行分类。训练这个分类器需要通过添加与扩散过程时间表匹配的噪声水平来扩充训练数据。它通常与扩散模型的主干网络结构类似,但输出类别概率。
- 采样循环修改:
- 在采样循环中,对于从 T 到 1 的每个时间步 t:
- 获取当前带噪声样本 xt。确保 xt 需要梯度。
- 获取无条件噪声预测 ϵuncond=ϵθ(xt,t)。
- 将 xt 通过噪声分类器 pϕ,获得所需目标类别 y 的对数概率 logpϕ(y∣xt)。
- 计算此对数概率相对于输入的梯度:g=∇xtlogpϕ(y∣xt)。这通常涉及调用
torch.autograd.grad 或其等效函数。请记住在将 xt 输入扩散模型之前分离它,如果您不希望梯度流经它。
- 计算引导噪声:ϵ^θ=ϵuncond−s⋅σt⋅g。
- 在您选择的采样器(DDPM, DDIM)中使用 ϵ^θ 来计算下一步的去噪估计 xt−1。
- 重复此过程直到获得 x0。
下图展示了单个引导采样步骤中的数据流:
分类器引导单步数据流。带噪声图像 xt 由扩散模型和噪声分类器共同处理。分类器针对目标类别 y 的输出用于计算梯度,该梯度随后被缩放并从扩散模型的噪声预测中减去,从而得到引导噪声 ϵ^θ。
分类器引导的优点
- 直接控制: 提供了一种直接控制生成过程,使其趋向于分类器可识别的特定类别或属性的方式。
- 潜在的质量提升: 与无条件生成相比,有时可以提升目标类别样本的质量和真实感,特别是当扩散模型在特定模式上表现不佳时。
缺点与挑战
- 需要独立的分类器: 主要的缺点是需要训练和维护一个额外的分类器模型。
- 噪声分类器训练: 这个分类器必须处理扩散采样过程中遇到的噪声水平,这使得其训练具有挑战性且计算成本高昂。它需要访问与扩散模型相同的噪声调度和数据。
- 引导尺度调整: 找到最佳引导尺度 s 是非常重要。过低,引导效果微弱。过高,过程可能对分类器过度优化,可能将分类器视为对抗者,生成对分类器看起来很好但却不自然或包含伪影的样本。这通常表现为过度饱和或纹理奇怪的图像。
- 灵活性有限: 尽管对类别有效,但与将条件直接嵌入扩散模型架构(例如,通过交叉注意力)的方法相比,将其应用于详细文本描述的直接性较低。
分类器引导是可控扩散生成中的一个重要步骤。然而,训练噪声分类器相关的实际挑战促使研究人员寻求替代方法。这为**无分类器引导(CFG)**奠定了基础,这是下一节讨论的技术,它巧妙地实现了类似的引导效果,而完全不需要外部分类器模型。