无条件生成会生成代表整体训练数据的样本。然而,我们常常需要更具体的输出。设想一下,如果只想生成猫的图像,或者对应数字‘8’的数字图像,我们如何控制已经学习了分布 p(x) 的扩散模型,使其从条件分布 p(x∣y) 生成样本呢?y 代表所需条件(例如 y= '猫' 或 y= '8')。
实现此目的的首批成功方法之一是分类器引导。其主要思路是借助一个单独的、预训练的分类器模型,我们称之为 pϕ(y∣x),其中 ϕ 代表分类器的参数。该分类器经过训练,能够根据输入 x 预测类别标签 y。
然而,在扩散模型的逆向采样过程中,我们处理的不是干净数据 x0,而是不同时间步 t 的含噪中间样本 xt。因此,为了使分类器引导有效,分类器 pϕ(y∣xt) 必须经过训练,即使从含噪输入 xt 中也能识别类别 y。这意味着分类器不仅要在原始数据集(如干净图像)上进行训练,还要在数据的含噪版本上进行训练,以与扩散过程中遇到的噪声水平保持一致。
这个分类器如何引导生成呢?回想一下,逆向过程旨在逼近 p(xt−1∣xt)。我们希望修改此步骤,从 p(xt−1∣xt,y) 中进行采样。使用贝叶斯定理,我们可以写出:
p(xt∣y)=p(y)p(y∣xt)p(xt)
对 xt 取对数再求梯度:
∇xtlogp(xt∣y)=∇xtlogp(y∣xt)+∇xtlogp(xt)
项 ∇xtlogp(xt) 是时间 t 处边缘数据分布的得分函数。扩散模型的噪声预测网络 ϵθ(xt,t) 经过训练,旨在逼近 −σt∇xtlogp(xt)(按噪声水平缩放)。项 ∇xtlogp(y∣xt) 是根据分类器,所需类别 y 对数似然的梯度,在当前含噪样本 xt 处评估。该梯度指出输入空间 (xt) 中使样本在分类器看来更像类别 y 的方向。
分类器引导通过加入该分类器梯度来修改采样步骤。具体而言,在计算逆向步骤 pθ(xt−1∣xt) 的均值 μθ(xt,t) 时,我们使用分类器的梯度对其进行扰动。DDPM 采样步骤中均值的更新规则可以调整如下:
原始预测均值源自预测噪声 ϵθ:
μθ(xt,t)=αt1(xt−1−αˉtβtϵθ(xt,t))
引导均值 μ^θ(xt,t,y) 变为:
μ^θ(xt,t,y)=μθ(xt,t)+s⋅Σt∇xtlogpϕ(y∣xt)
此处:
- Σt 是逆向步骤 pθ(xt−1∣xt) 的协方差矩阵(方差),通常与 βt 相关联。
- s 是引导强度,一个控制分类器影响强度的超参数。较高的 s 值会更强地推动生成朝向目标类别 y。
- ∇xtlogpϕ(y∣xt) 是由分类器 pϕ 计算的梯度。
采样过程随后使用此调整后的均值 μ^θ 进行:
- 使用 U-Net 预测噪声 ϵθ(xt,t)。
- 计算原始均值 μθ(xt,t)。
- 计算所需类别 y 的分类器梯度 ∇xtlogpϕ(y∣xt)。
- 使用引导强度 s 计算引导均值 μ^θ(xt,t,y)。
- 从 N(xt−1;μ^θ(xt,t,y),Σt) 中采样 xt−1。
此过程从 t=T 到 t=1 重复进行。
该图说明了单个逆向步骤中的分类器引导机制。U-Net 预测噪声,而一个单独的分类器则根据目标类别 y 提供梯度。这些被组合起来,并按 s 缩放,以生成用于采样 xt−1 的引导均值。
分类器引导的优点:
- 明确控制: 提供了一种直接方式,可将生成引导至分类器训练的特定属性。
- 使用现有分类器: 理论上可以使用现成的、训练有素的分类器(尽管它们需要具备噪声鲁棒性)。
分类器引导的缺点:
- 需要单独的分类器: 需要训练并维护一个额外的模型 pϕ(y∣xt)。
- 分类器训练: 分类器必须能够应对扩散过程中遇到的噪声水平,增加了训练的复杂性。
- 计算成本: 需要在每个采样步骤中对扩散模型和分类器都运行推理。
- 引导强度调整: 找到合适的引导强度 s 通常需要进行实验。过低时,引导效果不佳;过高时,样本可能变得不真实,或受到分类器梯度利用的对抗性影响。
尽管有效,对一个单独的、噪声感知的分类器的需求促使研究人员开发出在没有这种外部依赖的情况下实现类似引导的方法。这引出了我们即将讨论的下一项技术:无分类器引导(CFG)。