趋近智
实施无分类器引导(CFG)于扩散模型采样过程。为实现此目的,提供实际示例和代码结构。假设已拥有一个能够接受条件信息预训练 (pre-training)的U-Net模型。
请记住,CFG引导生成朝向条件(例如类别标签或文本嵌入 (embedding)),而无需单独的分类器模型。它通过运用扩散模型执行条件预测和无条件预测的能力来实现此目标。在每个时间步的采样过程中,我们计算:
这两个预测随后使用引导尺度进行组合:
这个是用于计算下一个噪声更少的状态的引导噪声估计。引导尺度控制生成与条件的符合程度。的值会忽略条件,导致无条件采样。增加会更强烈地促使生成朝向条件。
在开始采样循环之前,您需要准备好条件输入和空条件。
假设y_cond包含您期望输出的条件向量(例如,“猫”的嵌入或类别7),而y_null包含空条件的向量。
核心修改发生在采样循环内部(无论使用DDPM还是DDIM)。这是一个简化的类似Python的伪代码结构,假定使用类似PyTorch的框架,并且有一个函数get_denoised_xt_minus_1,它在给定、和预测噪声的情况下执行标准的反向步骤(如DDPM论文中的公式11或12,或DDIM更新):
# 假设model是您的U-Net,scheduler包含噪声调度信息
# x_t 开始时为纯噪声:x_T ~ N(0, I)
# timesteps是一个时间步的列表/张量,例如 [999, 998, ..., 0]
# y_cond是期望输出的条件向量
# y_null是空条件向量
# w是引导尺度(例如,7.5)
x_t = torch.randn_like(initial_sample_shape) # 从随机噪声x_T开始
for t_val in timesteps:
t_tensor = torch.tensor([t_val] * batch_size, device=x_t.device)
# 确保x_t在框架细节需要时,为模型输入要求梯度,
# 但我们通常在推理时不需要梯度。
# 使用torch.no_grad()是提高效率的常见做法。
with torch.no_grad():
# 1. 预测条件输入的噪声
pred_noise_cond = model(x_t, t_tensor, y_cond)
# 2. 预测无条件输入的噪声
pred_noise_uncond = model(x_t, t_tensor, y_null)
# 3. 使用CFG公式组合预测结果
guided_noise = pred_noise_uncond + w * (pred_noise_cond - pred_noise_uncond)
# 4. 使用引导噪声计算x_{t-1}
# 这一步取决于您是使用DDPM还是DDIM采样逻辑
# 假设一个函数封装了反向步骤的示例:
x_t = scheduler.step(guided_noise, t_val, x_t) # 将x_t更新为x_{t-1}
# 循环后的最终结果是x_0(生成的样本)
generated_sample = x_t
循环中的步骤:
t。y_cond传递给模型。y_null传递给模型。w计算guided_noise。guided_noise来计算。为下一次迭代更新x_t。对从到0的所有时间步重复此操作。最终的x_t将是您生成的样本。
w的选择很大程度上影响输出。
w值(例如,0或1):生成受条件的约束较少。如果,则是纯粹的无条件生成。如果,它遵循学习到的条件分布,但可能缺乏强烈的符合性。样本可能多样,但与提示y的对齐 (alignment)程度较低。w值(例如,3到10):通常是最佳点。在与条件y的符合性、整体样本质量和多样性之间取得平衡。生成的图像清晰地反映了条件。w值(例如,15+):强烈符合条件,但样本可能变得不那么多样,可能出现饱和或伪影。模型可能会过度强调与条件相关的特征。尝试不同w值是常见的做法,以找到特定模型和任务的最佳权衡。
让我们直观地看看改变w如何影响使用CFG在MNIST上训练的扩散模型生成数字“8”。
这张示意图展现了典型的权衡。随着引导尺度
w的增加,与条件(生成数字“8”)的符合度通常会提高(蓝线),但样本多样性和潜在的整体质量在某个点之后可能会下降(橙线),有时在极高值时会导致伪影。绿色阴影区域表示通常能找到良好平衡的常见范围。
重要的是要记住,采样过程中的CFG依赖于模型经过专门训练以处理条件和无条件输入。这通常在训练阶段通过使用条件丢弃来实现:
这迫使模型学习如何在特定条件引导下和在没有提供条件(使用空嵌入)时预测噪声。如果没有这种训练策略,模型将不知道如何解释空条件,并且CFG公式将不会产生有意义的引导。
通过实现此处描述的引导采样循环,并运用经过条件丢弃训练的模型,您可以有效地引导扩散过程生成符合您期望条件的输出。这很大程度上扩展了扩散模型提供的创作控制。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•