固定的噪声调度,无论是线性、余弦还是定制设计的,都为扩散过程提供了预设路径,但它们都基于一个假设:即逆向步骤中单一、预定的方差演变是最优的。然而,在逆向过程步骤 pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),σt2I) 中,理想方差 σt2 可能因时间步 t 甚至特定数据实例而异。固定方差,通常是采用从正向过程噪声调度(如 βt 或 β~t=1−αˉt1−αˉt−1βt)导出的选择,限制了模型的灵活性。
这种局限性促成了可学习方差调度的出现,这是一种在训练期间由扩散模型自身预测每个逆向步骤所需方差的技术。Nichol 和 Dhariwal 在《改进的去噪扩散概率模型》(2021)中对此方法进行了着重研究,它使得模型能够动态调整逆向过程的随机性。
为什么要学习方差?
学习方差赋予模型更强的表达能力。请思考以下几点:
- 最佳噪声水平: 逆向过程的不同阶段可能需要不同量的噪声。早期步骤(t 较大)可能需要更大的方差,以从纯噪声中产生显著变化;而后期步骤(t 较小)可能需要更小的方差,用于细节微调 (fine-tuning)。可学习方差使模型能够对此进行调整。
- 提升似然: 原始 DDPM 论文固定逆向过程方差主要是为了简化,优化了一个与噪声预测相关的替代目标(Lsimple)。学习方差允许直接优化与数据对数似然的变分下界 (VLB) 相关的内容,可能产生能够更好捕获真实数据分布并获得更高对数似然分数的模型。
- 数据依赖适应: 尽管通常是作为时间步 t 的函数来学习,但该机制也可能根据 xt 进行调整,尽管这不太常见。
方差的参数 (parameter)化与预测
不再固定 σt2,我们对其进行参数化,并让模型预测参数。一种常见方法是内插两种标准固定选择:βt 和 β~t。回想一下,βt 对应于步骤 t 的正向过程方差,而 β~t 旨在当 x0 已知时匹配后验方差 q(xt−1∣xt,x0)。
可学习方差 σθ,t2 可以参数化为:
σθ,t2=exp(vlogβt+(1−v)logβ~t)
其中,v 是神经网络 (neural network)预测的一个参数,通常约束在 0 到 1 之间。网络架构(例如 U-Net 或 Transformer)被修改,以输出一个额外的值(如果需要空间变化的方差,可以是一个值集合,每个像素/块一个,但通常每个时间步预测一个标量 v)代表 v,与用于平均值 μθ(xt,t) 的预测同时进行(通常从噪声预测 ϵθ(xt,t) 中导出)。
扩散模型接收带噪声的输入 xt 和时间步 t 嵌入 (embedding)。其输出被分为两部分,分别预测噪声 ϵθ(确定逆向过程均值)和方差参数 vθ。
调整训练目标
在学习方差时,训练目标需要考虑这一预测。原始 DDPM 的 Lsimple 目标只侧重于预测噪声 ϵ。为了训练方差预测 v,损失函数 (loss function)引入了一个从 VLB 导出的项,通常表示为 Lvlb。该项直接涉及预测的方差 σθ,t2。
一种常见做法是使用混合目标函数:
Lhybrid=Lsimple+λLvlb
其中 Lsimple=Et,x0,ϵ[∣∣ϵ−ϵθ(xt,t)∣∣2] 是标准噪声预测损失,而 Lvlb 是鼓励精确方差预测的项。超参数 (parameter) (hyperparameter) λ 平衡这两个目标。将 λ=0 设置为 0 可恢复标准 DDPM 的固定方差训练。Nichol 和 Dhariwal 发现,一个小的非零 λ(例如 λ=0.001)效果良好,保留了 Lsimple 在样本质量方面的优势,同时从 Lvlb 中获得了似然提升。
实现注意事项
实现可学习方差涉及以下主要修改:
- 模型输出: 调整网络的最后一层(U-Net 或 Transformer),使其输出通道数量是标准 ϵ 预测的两倍。一半表示 ϵθ,另一半表示计算 σθ,t2 所需的参数 (parameter)(例如,用于在 βt 和 β~t 之间进行内插的值 v)。
- 损失函数 (loss function): 实现混合损失 Lhybrid,计算 ϵθ 上的均方误差 (MSE) 损失以及基于预测方差的 Lvlb 项。
- 采样: 在采样过程中,使用从预测 ϵθ 导出的标准均值计算逆向步骤 xt−1,但将噪声分量 z∼N(0,I) 乘以预测的标准差 σθ,t,而不是固定的 σt。
xt−1=μθ(xt,t)+σθ,tz其中 当 t>1 时 z∼N(0,I),否则 z=0
优点与权衡
- 优点:
- 可以显著提高对数似然分数,相较于固定方差模型。
- 可能带来样本质量的适度提升(例如 FID 分数),尽管主要优势通常体现在似然上。
- 为模型提供更大的灵活性,以适应生成过程。
- 权衡:
- 增加模型复杂度,因为网络必须预测额外的参数 (parameter)。
- 混合损失函数 (loss function)增加了训练设置的复杂性。
- 需要仔细调整混合损失中的超参数 (hyperparameter) λ。
可学习方差调度代表了从固定或手动设计调度向前迈进了一步,帮助扩散模型优化逆向过程的一个基本方面。尽管增加了一些复杂性,似然和适应性方面的潜在收获使其成为高级扩散建模工具包中一项有价值的技术,尤其是在精确密度估计与样本质量同等重要时。