如我们所确认,直接计算真实的逆向概率 p(xt−1∣xt) 通常是不可解的。我们的做法是使用参数化分布来近似这种逆向转移,通常是高斯分布,通过神经网络学习得到:pθ(xt−1∣xt)。网络需要估计这种高斯分布的参数,具体而言是其均值 μθ(xt,t) 和方差 Σθ(xt,t)。
虽然网络可以被训练来直接预测均值 μθ(xt,t) 或者去噪后的样本 xt−1,但另一种方法在实践中已被证明非常有效:预测噪声分量 ϵ,即在正向过程中于时间步 t 添加的噪声。
我们来看看这为何有用。回顾一下正向过程中从 x0 直接采样得到 xt 的闭式表达式:
xt=αˉtx0+1−αˉtϵ
这里,ϵ 是一个标准高斯噪声样本,αˉt 则来源于噪声调度。此方程关联了原始数据 x0、带噪声版本 xt 以及为达到该状态而添加的噪声 ϵ。
如果我们的神经网络(我们称之为 ϵθ)能够根据带噪声的输入 xt 和时间步 t 准确预测噪声 ϵ,我们就可以利用这个预测值 ϵθ(xt,t) 来辅助我们对前一状态 xt−1 的估计。
预测噪声如何帮助估计逆向步骤 pθ(xt−1∣xt) 的均值 μθ(xt,t) 呢?最初的去噪扩散概率模型(DDPM)论文指出,逆向转移 p(xt−1∣xt,x0) 的均值可以表示为:
μ~t(xt,x0)=αt1(xt−1−αˉt1−αtϵ)
其中 αt=αˉt/αˉt−1。请注意,此表达式依赖于原始数据 x0(通过 ϵ),而我们在生成过程中是无法获得的。
然而,如果我们知道 xt 和 ϵ,我们可以重新排列第一个方程,得到 x0 的一个估计值:
x0≈x^0=αˉt1(xt−1−αˉtϵθ(xt,t))
通过将网络的噪声预测值 ϵθ(xt,t) 替换 ϵ 代入 μ~t 的方程中,我们得到了近似逆向转移 pθ(xt−1∣xt) 的均值表达式:
μθ(xt,t)=αt1(xt−1−αˉt1−αtϵθ(xt,t))
这建立了一个直接的联系:如果我们的网络 ϵθ(xt,t) 成功预测了在步骤 t 添加的噪声,我们就可以计算去噪步骤 xt→xt−1 所需的均值。方差 Σθ(xt,t) 通常被固定为与正向过程方差相关的值,或者有时也会被学习,但预测噪声主要用于确定均值。
为什么要预测噪声?
通过预测噪声来参数化逆向过程有几个益处:
- 更简单的学习目标: 噪声 ϵ 通常是从一个简单分布(例如标准高斯分布)中提取的。与直接预测数据 x0 或 xt−1 的复杂结构相比,预测这种噪声对神经网络来说可能是一个更容易的任务,尤其是在高噪声水平(t 值较大)时。
- 与损失函数的一致性: 正如我们将在下一章中看到的,扩散模型的训练目标通常简化为正向过程中添加的真实噪声 ϵ 与网络预测噪声 ϵθ(xt,t) 之间的均方误差损失。训练网络直接输出 ϵθ 使网络输出与此常用损失函数完全一致,可能带来更稳定的训练。
- 实践中的成效: 这种噪声预测的设定在许多高性能扩散模型的成功中起到了核心作用,特别是在图像生成方面。
因此,标准做法是训练一个神经网络 ϵθ,它接收带噪声的数据 xt 和时间步 t 作为输入,并输出用于从 x0 生成 xt 的噪声 ϵ 的预测值。这个预测的噪声 ϵθ(xt,t) 随后使我们能够计算近似逆向分布 pθ(xt−1∣xt) 的参数(特别是均值),从而实现逐步生成过程。
神经网络 ϵθ 接收当前的带噪声样本 xt 和时间步 t 作为输入。它的目标是预测最有可能添加到原始数据中以产生 xt 的噪声 ϵθ。这个预测是用于指导逆向去噪步骤的核心组成部分。