趋近智
我们将一致性蒸馏的理论付诸实践。本节将详细介绍如何实现一个基本的一致性蒸馏过程。我们将使用一个预先存在的扩散模型(我们的“教师模型”),并训练一个“学生”一致性模型,以快速近似其输出,目标是实现单步生成。
一致性特性的强制执行是其主要思想: 适用于由概率流ODE定义的轨迹上的所有 。蒸馏通过训练学生网络 来匹配目标网络 的输出,其中 和 是同一轨迹上的相邻点,而 表示缓慢更新的权重( 的EMA)以保证稳定性。
我们假定您已具备以下条件:
teacher_model(xt, t)。我们的目标是训练一个学生一致性模型 student_model(xt, t),它通常使用与教师模型相同的架构进行初始化。我们还需要一个学生模型的指数移动平均(EMA)版本 target_model(xt, t),其权重 会根据学生模型的权重 缓慢更新。
# 示例设置(类PyTorch)
import torch
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm # 用于进度显示
# 假定teacher_model已预加载(预测epsilon)
# teacher_model.eval() # 将教师模型设置为评估模式
# 初始化学生模型(与教师模型架构相同)
student_model = deepcopy(teacher_model)
student_model.train() # 将学生模型设置为训练模式
# 使用学生模型的初始权重初始化目标模型
target_model = deepcopy(student_model)
target_model.eval() # 目标模型仅用于推断
# 学生模型的优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
# 超参数
num_training_steps = 100000
batch_size = 64
ema_decay = 0.99 # 典型的EMA衰减率
N = 100 # 训练时的离散化步数
T = teacher_model.num_timesteps # 教师模型的最大时间步
# 损失函数(例如,L2距离)
def consistency_loss_fn(online_pred, target_pred):
return F.mse_loss(online_pred, target_pred)
# 从epsilon预测获取x0预测的函数
def get_x0_from_epsilon(xt, epsilon, t, alphas_cumprod):
alpha_t_cumprod = alphas_cumprod[t].view(-1, 1, 1, 1) # 确保正确的形状
return (xt - torch.sqrt(1.0 - alpha_t_cumprod) * epsilon) / torch.sqrt(alpha_t_cumprod)
# 假定'get_dataloader()'提供x0批次数据
dataloader = get_dataloader(batch_size)
# 假定'alphas_cumprod'包含教师模型噪声调度中的(1 - beta_t)的累积乘积
alphas_cumprod = teacher_model.alphas_cumprod.to(device)
一致性蒸馏的核心在于迭代采样相邻时间步对,计算相应的含噪样本,并最小化学生模型在较早时间点的预测与目标模型在较晚时间点的预测之间的差异。
以下是单个训练步骤的分解:
# 训练循环代码片段(简化版)
for step in tqdm(range(num_training_steps)):
optimizer.zero_grad()
x0 = next(iter(dataloader)).to(device) # 1. 采样数据
# 2. 采样时间步(索引n从1到N-1)
n = torch.randint(1, N, (batch_size,), device=device)
t_i = (n / N) * T
t_i_plus_1 = ((n + 1) / N) * T
# 如果模型需要离散步,确保时间步为整数
t_idx_i = n.long() # 或者将连续的t映射到离散索引
t_idx_i_plus_1 = (n + 1).long()
# 3. 生成含噪样本(使用教师模型的调度逻辑)
noise_i = torch.randn_like(x0)
noise_i_plus_1 = torch.randn_like(x0) # 通常重新使用噪声以减少方差
xt_i = get_noisy_version(x0, t_idx_i, noise_i) # 基于DDPM前向过程的函数
xt_i_plus_1 = get_noisy_version(x0, t_idx_i_plus_1, noise_i_plus_1) # 同上
# 4. 获取目标预测(使用target_model并转换epsilon->x0)
with torch.no_grad():
target_epsilon = target_model(xt_i_plus_1, t_idx_i_plus_1)
target_x0_pred = get_x0_from_epsilon(xt_i_plus_1, target_epsilon, t_idx_i_plus_1, alphas_cumprod)
# 5. 获取学生预测(使用student_model并转换epsilon->x0)
student_epsilon = student_model(xt_i, t_idx_i)
student_x0_pred = get_x0_from_epsilon(xt_i, student_epsilon, t_idx_i, alphas_cumprod)
# 6. 计算损失
loss = consistency_loss_fn(student_x0_pred, target_x0_pred)
# 7. 更新学生模型
loss.backward()
optimizer.step()
# 8. 更新目标网络(EMA)
for param, target_param in zip(student_model.parameters(), target_model.parameters()):
target_param.data.mul_(ema_decay).add_(param.data, alpha=1 - ema_decay)
if step % 1000 == 0:
print(f"Step: {step}, Loss: {loss.item()}")
# 可选:保存检查点,生成样本图像
下图显示了单个训练步骤的数据流:
图中显示了一致性蒸馏训练单步的数据流。输入(数据、时间、噪声)用于生成相邻的含噪样本,这些样本被送入学生模型和目标模型。损失函数比较它们对 的估计,从而驱动学生权重()和EMA目标权重()的更新。
一旦 student_model(或者从技术上讲,通常用于推断的是包含EMA权重的最终 target_model)训练完成,采样过程将变得非常简单:
# 单步采样
consistency_model = target_model # 使用EMA模型进行推断
consistency_model.eval()
with torch.no_grad():
z = torch.randn(num_samples, *data_shape).to(device) # 采样噪声 (x_T)
t_max = torch.full((num_samples,), T-1, dtype=torch.long, device=device) # 最大时间步索引
# 获取T时的epsilon预测
pred_epsilon = consistency_model(z, t_max)
# 转换为x0预测
generated_x0 = get_x0_from_epsilon(z, pred_epsilon, t_max, alphas_cumprod)
# 'generated_x0' 包含最终的样本。
就是这样!单次前向传播即可生成样本。为了在略微增加计算成本的情况下获得更高质量,可以使用多步采样,其中涉及类似于DDIM的中间步骤,但使用的是一致性函数 。然而,其主要吸引力在于单步生成所带来的显著计算量减少。
ema_decay 控制目标网络适应的速度。接近1的值(例如0.99、0.999)提供稳定性。本实践练习为理解一致性蒸馏的工作原理奠定了基础。通过实现这个基本版本,您可以更好地了解训练模型以进行快速、少步生成的机制,这是使扩散模型在实时应用中更具实用性的重要进展。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造