趋近智
训练扩散模型时的一个重要决定是,确定底层神经网络 (neural network)应该预测什么。给定含噪声的输入 和时间步 ,模型需要估算与去噪过程相关的某个量。两种最常见的参数 (parameter)化方式是预测被添加的噪声 (),或预测原始干净数据 ()。这个选择会影响损失函数 (loss function)、训练动态,并可能影响最终的采样质量。
这是最初的去噪扩散概率模型(DDPM)论文中提出的标准方法。模型 被训练来预测噪声 ,该噪声是从标准高斯分布中采样并添加到原始数据 中以生成 ,其依据是前向过程方程:
这里, 是噪声调度方差直到时间 的累积乘积。目标函数通常是预测噪声 与用于生成 的实际噪声 之间的简化均方误差(MSE)损失:
期望是针对随机时间步 、初始数据样本 和采样噪声 计算的。
优点:
在采样(逆向过程)期间,预测噪声 用于估算趋向噪声较小状态的方向,通常是通过先估算预测的 (记作 ),然后将其用于DDPM或DDIM更新步骤。
另一种方法是将模型(我们称之为 )参数 (parameter)化,使其能够从含噪声的输入 和时间步 直接预测原始干净数据 。相应的MSE损失函数 (loss function)旨在最小化预测的 与真实 之间的差异:
优点:
参数化之间的关联:
这两种参数化方式在数学上是相关联的。给定前向过程方程 ,我们可以用一种预测表示另一种预测:
如果模型预测噪声 ,则对应的 预测为:
如果模型预测干净数据 ,则对应的 预测为:
这些关系表明,选择一种参数化方式隐式定义了另一种。然而,训练网络以预测其中一个量与预测另一个量会直接影响梯度和表现,从而可能导致不同的训练动态和最终模型表现。
预测和预测方法的比较。核心模型架构通常是相同的,但损失计算的目标变量不同。
import torch
import torch.nn.functional as F
def get_sqrt_alphas_cumprod(alphas):
"""获取累积乘积的辅助函数"""
return torch.sqrt(torch.cumprod(alphas, dim=0))
def get_sqrt_one_minus_alphas_cumprod(alphas):
"""获取 sqrt(1 - alpha_bar) 的辅助函数"""
return torch.sqrt(1.0 - torch.cumprod(alphas, dim=0))
# --- 示例设置 ---
T = 1000
betas = torch.linspace(0.0001, 0.02, T) # 示例线性调度
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
# --- 假设我们每个批次项目有以下变量 ---
# x_start: 原始干净图像 [B, C, H, W]
# noise: 采样高斯噪声 (epsilon) [B, C, H, W]
# t: 采样时间步 [B]
# model: 您的U-Net或Transformer模型
# 提取批次时间步 t 的调度值
sqrt_alpha_bar_t = sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
# 计算含噪声图像 x_t
x_t = sqrt_alpha_bar_t * x_start + sqrt_one_minus_alpha_bar_t * noise
# 获取模型预测
# model_output 形状: [B, C, H, W]
model_output = model(x_t, t)
# --- 损失计算 ---
# 1. Epsilon预测损失
target_eps = noise
loss_eps = F.mse_loss(model_output, target_eps)
# 使用 loss_eps 进行反向传播
# 2. x0预测损失
target_x0 = x_start
loss_x0 = F.mse_loss(model_output, target_x0)
# 使用 loss_x0 进行反向传播
# --- 采样考量 (DDIM步骤示例) ---
# 如果模型预测 epsilon (model_output = predicted_eps):
predicted_x0_from_eps = (x_t - sqrt_one_minus_alpha_bar_t * model_output) / sqrt_alpha_bar_t
# 在DDIM更新公式中使用 predicted_x0_from_eps
# 如果模型预测 x0 (model_output = predicted_x0):
predicted_eps_from_x0 = (x_t - sqrt_alpha_bar_t * model_output) / sqrt_one_minus_alpha_bar_t
# 在DDIM更新公式中使用 predicted_eps_from_x0 (或修改DDIM以直接使用 predicted_x0)
Python代码片段,说明了预测和预测损失的目标计算差异,以及如何在采样时从一种预测推导出另一种预测。
虽然预测是完善且通常更稳定的默认选项,预测提供了一个有效的替代方案。尝试预测可能是有益的,如果:
了解其他参数化方式也很重要,例如下一节将讨论的预测,它旨在结合和预测的优点,特别是改进了模型在不同噪声水平下输出的尺度调整方式。了解预测与的含义,有助于理解这些更高级的技术。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造