一致性训练 (CT) 提供了一种从零开始训练一致性模型的方法,无需依赖预训练 (pre-training)的扩散模型。这与一致性蒸馏等其他技术形成对比,后者通常用于加速现有扩散模型。这种独立方法直接优化神经网络 (neural network) fheta(x,t),使其满足一致性性质:沿着相同概率流 (PF) ODE 轨迹的点应映射到相同的起点 x0。
主要挑战在于,在无法明确获得由预训练模型定义的 ODE 轨迹的情况下,如何强制执行此性质。独立一致性训练通过使用数值 ODE 求解器以及模型自身在训练中不断变化的预测,巧妙地解决了这个问题。
一致性训练目标
目标是学习一个函数 fθ(x,t),使得对于任何有效时间 t∈[ϵ,T] 和来自 x0 的 ODE 轨迹上的任何点 xt,我们都有 fθ(xt,t)≈x0。此处,ϵ 是接近 0 的一个很小的最小时间步长,而 T 是最大扩散时间。
为此,CT 最小化一个损失函数 (loss function),该函数鼓励模型在估计的 ODE 轨迹上、于相邻时间步评估的输出之间保持一致。考虑离散化时间表 t1,t2,...,tN 中的两个连续时间步 tn 和 tn+1,其中 t1≈ϵ 且 tN=T。设 xtn 和 xtn+1 是相同(估计的)ODE 轨迹上的点。一致性损失强制要求模型从这两个点对 x0 的预测是相似的:
LCT(θ,θ−)=Ex0∼pdata,n∼U(1,N−1),z∼N(0,I)[λ(tn+1)d(fθ(xtn+1,tn+1),fθ−(xtn,tn))]
让我们分解此目标:
- 采样: 我们采样一个真实数据点 x0、一个随机时间步索引 n 和一个噪声向量 (vector) z。
- 轨迹点: 我们需要获得 xtn 和 xtn+1。这些点是通过从 x0 扰动后的点开始,迈出数值 ODE 求解器(如一阶欧拉或二阶Heun)的一步来估计的。例如,对 PF ODE dtdx=21β(t)(∇xlogpt(x)+x) 使用欧拉方法:
- xtn≈x0+α(tn)2/(1−α(tn)2)z(基于扩散过程性质的近似)
- xtn+1 是通过使用 ODE 求解器从 xtn 迈出一步获得的。这需要在 (xtn,tn) 处估计得分 ∇xlogpt(x)。要紧的是,CT 通常使用当前模型对 x0 的估计所隐含的得分或相关技术,有效地引导该过程。
- 一致性函数 fθ: 这是正在训练的神经网络 (neural network)。它接受噪声输入 xt 和时间 t,并预测相应的 x0。
- 目标网络 fθ−: 主网络参数 (parameter) (θ) 的缓慢更新的指数移动平均 (EMA) 用于“目标”预测 fθ−(xtn,tn)。这稳定了训练,类似于强化学习 (reinforcement learning)或 BYOL 中使用的技术。更新规则通常是 θ−←μθ−+(1−μ)θ,其中 μ 是一个接近 1 的动量系数(例如 0.999)。
- 距离度量 d(⋅,⋅): 这衡量了两个预测之间的差异。常见选择包括 L1 损失、L2 损失(均方误差)或诸如 LPIPS 的感知损失。
- 加权函数 λ(tn+1): 此函数根据时间步对损失进行加权。它通常优先考虑在较早时间步(更接近数据)的匹配,或使用源自扩散模型理论的加权方案。
轨迹和得分估计
与蒸馏最主要的不同在于如何获得 xtn 和 xtn+1(相同轨迹上的点)。由于没有教师模型提供得分 ∇xlogpt(x),CT 必须估计它。
一种常见做法是使用得分函数与条件期望 E[x0∣xt] 之间的关系。如果模型 fθ(xt,t) 估计 x0,它就可以用来近似生成 xtn+1 所需的 ODE 求解器步骤的得分。这创建了一个自监督循环,模型同时完善其对轨迹和一致性映射的理解。
训练算法概述
独立一致性训练过程在每次迭代中通常遵循以下步骤:
- 采样数据: 从真实数据分布 pdata 中抽取一个数据点小批量 {x0(i)}。
- 采样时间步: 对于每个 x0(i),采样一个时间索引 n(i)∼U(1,N−1)。令 tn=schedule[n] 且 tn+1=schedule[n+1]。
- 生成轨迹对:
- 生成噪声 z(i)∼N(0,I)。
- 估计 xtn(i)(例如,根据扩散过程定义使用 x0(i) 和 z(i))。
- 估计 (xtn(i),tn) 处的得分,可能使用 fθ 或 fθ−。
- 使用数值 ODE 求解器(例如,Heun 方法的一步)和估计的得分,从 xtn(i) 计算 xtn+1(i)。
- 计算模型输出:
- 在线网络预测:yn+1(i)=fθ(xtn+1(i),tn+1)
- 目标网络预测:yn(i)=fθ−(xtn(i),tn)(阻止梯度通过目标网络)。
- 计算损失: 使用距离度量 d 和加权 λ(tn+1) 计算一致性损失 LCT:
损失=B1i=1∑Bλ(tn+1(i))d(yn+1(i),yn(i))
其中 B 是批量大小。
- 梯度更新: 计算梯度 ∇θLCT 并使用优化器(例如 Adam)更新在线网络参数 (parameter) θ。
- 更新目标网络: 使用 EMA 更新目标网络参数 θ−:θ−←μθ−+(1−μ)θ。
此图说明了单个数据点的独立一致性训练循环。该过程包括采样数据和时间、使用得分估计沿着估计的 ODE 轨迹生成一对点、计算在线网络和目标网络的输出,以及根据一致性损失更新网络。
架构考量
独立 CT 中使用的网络架构 fθ(x,t) 通常类似于标准扩散模型或一致性蒸馏中采用的架构,例如 U-Net 变体或 Transformer(如 DiT)。主要要求是能够处理噪声输入 xt 和时间嵌入 (embedding) t 以产生 x0 的估计。调整可能包括修改时间嵌入方式或条件信息(如果有)的整合方式。
与蒸馏的比较
- 独立训练的优点:
- 独立性:不需要预训练 (pre-training)的扩散模型,节省了首先训练一个模型相关的计算成本和时间。
- 较高保真度的可能性:可能避免蒸馏过程中固有的限制,即学生模型试图模仿一个可能不完美的教师。
- 独立训练的缺点:
- 稳定性:可能比蒸馏更难稳定,因为它依赖于引导得分估计。需要仔细调整学习率、EMA 衰减率 (μ) 和损失加权 (λ(t)) 等超参数 (parameter) (hyperparameter)。
- 收敛速度:与蒸馏相比,可能需要更多的训练迭代才能收敛,而蒸馏从强大的教师模型获得指导。
独立一致性训练代表着高效生成建模向前推进了一大步,使得能够直接从数据创建快速的、少步甚至单步生成模型。尽管与蒸馏相比它带来了独特的训练挑战,但其独立于预训练模型的特性使其成为生成建模工具箱中一种很有吸引力的有效方法。