无分类器引导(CFG)能够在不需要独立分类器模型的情况下实现条件生成。本次动手练习展示了如何修改标准扩散模型采样循环以加入CFG。我们假设您已拥有一个预训练的条件扩散模型(如U-Net或DiT)和一个基本的采样函数(例如实现DDIM的函数)。正如所讨论的,CFG的核心思想是在反向过程中,在每个时间步 $t$ 计算两个预测:一个使用引导条件 $c$(例如文本嵌入、类别标签)的条件预测 $\epsilon_\theta(x_t, c)$,以及一个使用空条件 $\emptyset$ 的无条件预测 $\epsilon_\theta(x_t, \emptyset)$。然后,使用引导比例 $s$(在实现中常表示为 $w$)将它们组合起来,以指引生成过程:$$ \tilde{\epsilon}\theta(x_t, c) = \epsilon\theta(x_t, \emptyset) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset)) $$这个调整后的噪声估计 $\tilde{\epsilon}_\theta(x_t, c)$ 随后用于采样器的更新步骤。比例 $s=0$ 对应于无条件生成,而 $s=1$ 仅使用条件预测(假设模型经过条件Dropout训练)。$s > 1$ 的值会放大引导信号。前提条件一个预训练的条件扩散模型,能够接受条件信息 $c$ 和空条件 $\emptyset$ 的表示。经过条件Dropout(训练期间随机设置 $c=\emptyset$)训练的模型是合适的。一个实现扩散采样器(例如DDIM)的函数。条件信息 $c$(例如文本嵌入)和一个对应的空条件张量 $\emptyset$。修改采样循环我们假设您有一个标准的DDIM采样函数,其结构大致如下(简化的类似Python的伪代码):def ddim_sample_loop(model, x_T, timesteps, condition, eta=0.0): x_t = x_T for t_idx, t in enumerate(timesteps): # 获取当前和上一个时间步 time_tensor = torch.tensor([t], device=x_t.device) prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1 # 1. 使用模型预测噪声 predicted_noise = model(x_t, time_tensor, condition) # 原始预测 # 2. 计算 x_0 预测(使用DDIM公式组件,如alpha_t) # ... 计算 pred_x0 ... # 3. 计算指向 x_t 的方向 # ... 计算 dir_xt ... # 4. 计算随机性噪声(如果 eta > 0) # ... 计算 sigma_t 和 random_noise ... # 5. 计算下一个样本 x_{t-1} x_prev = pred_x0 + dir_xt + sigma_t * random_noise x_t = x_prev return x_t为了实现CFG,我们需要修改步骤1,即模型预测噪声的部分:准备输入: 确保您同时拥有目标条件 $c$ 和空条件 $\emptyset$。空条件通常是与 $c$ 形状相同的零张量,或者是为无条件情况学习到的特定嵌入。批量输入(可选但高效): 如果处理多个样本或模型支持批量操作,通常可以将条件和无条件输入沿批处理维度连接起来,以便在单个模型前向传递中执行两个预测。对于单个样本,这意味着创建一个大小为2的批次。执行预测: 调用模型两次(或一次使用批处理):noise_pred_cond = model(x_t, time_tensor, c)noise_pred_uncond = model(x_t, time_tensor, null_condition)组合预测: 使用引导比例 $s$ 应用CFG公式:guidance_scale = spredicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)使用组合预测: 在原始采样循环的后续步骤(2-5)中使用这个 predicted_noise。实现示例(PyTorch风格)下面是修改后的循环部分可能的样子,假设 model 接收 x_t、time_tensor 和 cond 作为输入,并且提供了 guidance_scale ($s$):def ddim_sample_loop_cfg(model, x_T, timesteps, condition, null_condition, guidance_scale, eta=0.0): x_t = x_T batch_size = x_t.shape[0] # 假设 x_T 的形状为 [B, C, H, W] for t_idx, t in enumerate(timesteps): time_tensor = torch.full((batch_size,), t, device=x_t.device, dtype=torch.long) prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1 # 1. 使用 CFG 预测模型噪声 # 高效地预测条件和无条件噪声 # 需要模型能够处理批量条件 model_input = torch.cat([x_t] * 2) # 复制输入用于条件/无条件 time_input = torch.cat([time_tensor] * 2) condition_input = torch.cat([condition, null_condition]) # 批量条件 # 单次模型调用以提高效率 noise_pred_combined = model(model_input, time_input, condition_input) # 分割预测结果 noise_pred_cond, noise_pred_uncond = noise_pred_combined.chunk(2) # 应用 CFG 公式 predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # --- DDIM 步骤的其余部分(步骤 2-5) --- # 根据 t 和 prev_t 计算 alpha_t, alpha_t_prev, sigma_t # (假设这些值已预计算或从噪声调度中计算得出) alpha_t = get_alpha(t) alpha_t_prev = get_alpha(prev_t) if prev_t >= 0 else 1.0 sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev)) # 2. 计算 x_0 预测 pred_x0 = (x_t - torch.sqrt(1.0 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t) # 可选:将预测的 x0 限制在 [-1, 1] 或其他有效范围 # pred_x0 = torch.clamp(pred_x0, -1.0, 1.0) # 3. 计算指向 x_t 的方向 dir_xt = torch.sqrt(1.0 - alpha_t_prev - sigma_t**2) * predicted_noise # 4. 计算随机性噪声(如果 eta > 0) random_noise = torch.randn_like(x_t) if eta > 0 and prev_t >= 0 else torch.zeros_like(x_t) # 5. 计算下一个样本 x_{t-1} x_prev = torch.sqrt(alpha_t_prev) * pred_x0 + dir_xt + sigma_t * random_noise x_t = x_prev # --- DDIM 步骤结束 --- return x_t # 返回最终生成的样本 x_0注意: get_alpha(t) 函数从时间步 $t$ 的噪声调度中获取方差的累积乘积 $\bar{\alpha}_t$。具体实现取决于您的噪声调度是如何定义的。调整引导比例 ($s$)guidance_scale ($s$) 是一个超参数,它控制着样本质量/多样性与对条件 $c$ 的遵循程度之间的权衡。低 $s$ 值(例如1.0 - 3.0): 样本倾向于更多样化,并且在纯图像质量方面可能具有更高的保真度,但对条件 $c$ 的遵循可能不那么严格。$s=0$ 是纯粹的无条件生成。中等 $s$ 值(例如4.0 - 8.0): 通常能提供良好的平衡。样本能很好地遵循条件,而不会牺牲太多质量。这是许多文本到图像模型的常用范围。高 $s$ 值(例如9.0 - 15.0+): 样本强烈遵循条件 $c$。然而,高值有时可能导致饱和、伪影或多样性降低,因为模型被强烈地推向条件预测。需要通过实验来找到适用于您的特定模型、数据集和任务的最佳 $s$ 值。您可以生成不同 $s$ 值的样本,并进行定性评估或使用合适的指标进行评估。{"data": [{"x": [1, 2, 3, 5, 7, 10, 15], "y": [5, 6, 7, 8.5, 9, 8, 7], "name": "条件遵循度", "type": "scatter", "mode": "lines+markers", "marker": {"color": "#4263eb"}, "line": {"color": "#4263eb"}}, {"x": [1, 2, 3, 5, 7, 10, 15], "y": [8.5, 8.5, 8, 7.5, 7, 6, 5], "name": "样本多样性/质量", "type": "scatter", "mode": "lines+markers", "marker": {"color": "#12b886"}, "line": {"color": "#12b886"}}], "layout": {"title": {"text": "引导比例 (s) 的权衡示意图"}, "xaxis": {"title": {"text": "引导比例 (s)"}}, "yaxis": {"title": {"text": "主观评分(越高越好)"}, "range": [0, 10]}, "legend": {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01}, "margin": {"l": 50, "r": 20, "t": 50, "b": 40}}}该图主观地说明了增加引导比例 s 通常如何提高条件遵循度(例如,匹配文本提示),同时在高值下可能降低整体样本多样性或引入伪影。最佳值需要在这些方面取得平衡。其他考量计算成本: CFG 使每个采样步骤的计算成本大约是标准条件采样的两倍,因为它需要两次模型前向传递(或一次通过翻倍的批次)。模型训练: 当模型经过条件Dropout训练时,此技术效果最佳,使其能够有效地学习条件和无条件的生成路径。采样器交互: CFG 调整通常在采样器(DDIM、DPM-Solver等)使用噪声预测 ($\epsilon$) 之前 应用。采样器逻辑的其余部分通常保持不变。通过实现CFG,您可以获得一种在扩散模型中控制条件生成的有效方法,而无需依赖外部分类器,直接使用生成模型内部学到的表示。实验引导比例是为您的特定应用获得期望输出的重要部分。