Okay, let's translate the theory of consistency distillation into practice. This section provides a hands-on walkthrough for implementing a basic consistency distillation process. We'll take a pre-existing diffusion model (our "teacher") and train a "student" consistency model to approximate its outputs rapidly, aiming for single-step generation.
Recall from the chapter introduction that the core idea is to enforce the consistency property f(x_t, t) ≈ x_0 for every t along a trajectory of the probability flow ODE. Distillation achieves this by training a student network f_θ(x, t) to match the output of a target network f_θ⁻(x', t'), where x and x' are adjacent points on the same trajectory and θ⁻ is a slowly updated exponential moving average (EMA) of θ, kept for stability.
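Written out, the distillation objective from the consistency models formulation takes roughly the following form, where t_n < t_{n+1} are adjacent points on a discretized time grid, d(·, ·) is a distance such as squared L2, and x̂_{t_n} is the point reached from x_{t_{n+1}} by one step of an ODE solver driven by the teacher (a sketch of the standard loss, not code from this walkthrough):

\[
\mathcal{L}(\theta, \theta^{-}) \;=\; \mathbb{E}\Big[\, d\big(f_{\theta}(x_{t_{n+1}}, t_{n+1}),\; f_{\theta^{-}}(\hat{x}_{t_{n}}, t_{n})\big) \,\Big]
\]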
We assume you have access to a pre-trained teacher diffusion model, teacher_model(xt, t), which predicts the noise epsilon added at timestep t. Our goal is to train a student consistency model, student_model(xt, t), often initialized with the same architecture (and weights) as the teacher. We also maintain an exponential moving average (EMA) copy of the student, target_model(xt, t), whose weights θ⁻ are updated slowly from the student's weights θ.
# Example setup (PyTorch-like)
import torch
import torch.nn.functional as F
from copy import deepcopy
from tqdm import tqdm  # For progress visualization

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assume teacher_model is pre-loaded (predicts epsilon)
teacher_model = teacher_model.to(device)
teacher_model.eval()  # The teacher is frozen; evaluation mode only

# Initialize the student with the teacher's architecture and weights
student_model = deepcopy(teacher_model).to(device)
student_model.train()  # Set student to training mode

# Initialize the target model with the student's initial weights
target_model = deepcopy(student_model).to(device)
target_model.eval()  # The target model is only used for inference
for p in target_model.parameters():
    p.requires_grad_(False)  # Updated via EMA, never by the optimizer

# Optimizer for the student model
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)

# Hyperparameters
num_training_steps = 100000
batch_size = 64
ema_decay = 0.99  # Typical EMA decay rate; values near 1 change the target slowly
N = 100  # Number of discretization steps for training
T = teacher_model.num_timesteps  # Number of timesteps in the teacher's schedule

# Loss function (e.g., squared L2 distance)
def consistency_loss_fn(online_pred, target_pred):
    return F.mse_loss(online_pred, target_pred)

# Convert an epsilon prediction into an x0 prediction under the DDPM parameterization
def get_x0_from_epsilon(xt, epsilon, t, alphas_cumprod):
    alpha_t_cumprod = alphas_cumprod[t].view(-1, 1, 1, 1)  # Broadcast over image dims
    return (xt - torch.sqrt(1.0 - alpha_t_cumprod) * epsilon) / torch.sqrt(alpha_t_cumprod)

# Assume 'get_dataloader()' provides batches of clean data x0
dataloader = get_dataloader(batch_size)

# 'alphas_cumprod' holds the cumulative product of alpha_t = (1 - beta_t)
# from the teacher's noise schedule
alphas_cumprod = teacher_model.alphas_cumprod.to(device)
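The training loop below also calls a helper, get_noisy_version, that is not defined above. A minimal sketch, assuming it applies the teacher's standard DDPM (variance-preserving) forward process using the same alphas_cumprod tensor, could look like this:

# Minimal sketch of the forward-process helper used in the training loop.
# Assumes the standard DDPM forward process with the teacher's alphas_cumprod.
def get_noisy_version(x0, t, noise):
    alpha_t_cumprod = alphas_cumprod[t].view(-1, 1, 1, 1)
    return torch.sqrt(alpha_t_cumprod) * x0 + torch.sqrt(1.0 - alpha_t_cumprod) * noise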
The core of consistency distillation involves repeatedly sampling pairs of adjacent timesteps, constructing the corresponding noisy samples, and minimizing the difference between the student's prediction at the later (noisier) timestep and the target model's prediction at the adjacent earlier timestep.
Here's a breakdown of one training step:
# Training Loop Snippet (Simplified)
data_iter = iter(dataloader)
for step in tqdm(range(num_training_steps)):
    optimizer.zero_grad()

    # 1. Sample a batch of clean data, cycling through the dataloader
    try:
        x0 = next(data_iter).to(device)
    except StopIteration:
        data_iter = iter(dataloader)
        x0 = next(data_iter).to(device)

    # 2. Sample adjacent discretization indices n and n + 1 (n from 1 to N - 1),
    #    then map them onto the teacher's discrete timestep grid [0, T - 1]
    n = torch.randint(1, N, (batch_size,), device=device)
    t_idx_i = torch.round(n.float() / N * (T - 1)).long()               # earlier, less noisy timestep
    t_idx_i_plus_1 = torch.round((n.float() + 1) / N * (T - 1)).long()  # adjacent later timestep

    # 3. Generate the two adjacent noisy samples with the SAME noise so they
    #    approximate adjacent points on one trajectory (DDPM forward process)
    noise = torch.randn_like(x0)
    xt_i = get_noisy_version(x0, t_idx_i, noise)
    xt_i_plus_1 = get_noisy_version(x0, t_idx_i_plus_1, noise)

    # 4. Target prediction at the earlier timestep (EMA weights, no gradients)
    with torch.no_grad():
        target_epsilon = target_model(xt_i, t_idx_i)
        target_x0_pred = get_x0_from_epsilon(xt_i, target_epsilon, t_idx_i, alphas_cumprod)

    # 5. Student prediction at the later, noisier timestep
    student_epsilon = student_model(xt_i_plus_1, t_idx_i_plus_1)
    student_x0_pred = get_x0_from_epsilon(xt_i_plus_1, student_epsilon, t_idx_i_plus_1, alphas_cumprod)

    # 6. Calculate the consistency loss between the two x0 estimates
    loss = consistency_loss_fn(student_x0_pred, target_x0_pred)

    # 7. Update the student
    loss.backward()
    optimizer.step()

    # 8. Update the target network (EMA of the student's weights)
    with torch.no_grad():
        for param, target_param in zip(student_model.parameters(), target_model.parameters()):
            target_param.data.mul_(ema_decay).add_(param.data, alpha=1 - ema_decay)

    if step % 1000 == 0:
        print(f"Step: {step}, Loss: {loss.item():.4f}")
        # Optional: Save checkpoint, generate sample images
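The loop above keeps things simple by re-noising x0 directly at both timesteps. In full consistency distillation, the less-noisy point is instead obtained from xt_i_plus_1 by one step of an ODE solver driven by the teacher, so that both inputs lie on the same probability flow trajectory. A sketch of that step, assuming the teacher predicts epsilon and using a deterministic DDIM-style update, might look like this:

# Sketch: one deterministic DDIM-style teacher step from t_{i+1} down to t_i.
# In a full distillation setup this replaces the direct re-noising of x0 at t_i.
with torch.no_grad():
    teacher_epsilon = teacher_model(xt_i_plus_1, t_idx_i_plus_1)
    # Teacher's x0 estimate at the later timestep
    teacher_x0 = get_x0_from_epsilon(xt_i_plus_1, teacher_epsilon, t_idx_i_plus_1, alphas_cumprod)
    alpha_i = alphas_cumprod[t_idx_i].view(-1, 1, 1, 1)
    # Recombine the x0 estimate and predicted noise at the earlier timestep (DDIM, eta = 0)
    xt_i_from_teacher = torch.sqrt(alpha_i) * teacher_x0 + torch.sqrt(1.0 - alpha_i) * teacher_epsilon
    # The target model would then be evaluated on xt_i_from_teacher at t_idx_i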
The diagram below illustrates the flow for a single training step:
A diagram showing the data flow during one step of consistency distillation training. Inputs (data, time, noise) are used to generate adjacent noisy samples, which feed into the student and target models. The loss compares their x0 estimates, driving updates to the student weights (θ) and the EMA target weights (θ−).
Once the student_model is trained (in practice, the final target_model holding the EMA weights is often preferred for inference), sampling becomes remarkably simple:
# Single-step sampling
consistency_model = target_model  # Use the EMA model for inference
consistency_model.eval()

num_samples = 16
data_shape = (3, 32, 32)  # Example shape; match your data (channels, height, width)

with torch.no_grad():
    z = torch.randn(num_samples, *data_shape).to(device)  # Sample pure noise (x_T)
    t_max = torch.full((num_samples,), T - 1, dtype=torch.long, device=device)  # Max timestep index

    # Predict epsilon at the maximum timestep and convert it to an x0 prediction
    pred_epsilon = consistency_model(z, t_max)
    generated_x0 = get_x0_from_epsilon(z, pred_epsilon, t_max, alphas_cumprod)

# 'generated_x0' contains the final samples.
That's it! A single forward pass generates a sample. For potentially higher quality at a modest extra cost, multi-step sampling can be used: the model's x0 estimate is partially re-noised to an intermediate timestep and denoised again, similar in spirit to DDIM but using the consistency function f_θ⁻. The primary appeal, however, is the drastic reduction in sampling cost that single-step generation offers.
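A minimal two-step sketch of that idea, reusing the model and schedule defined above and an illustrative intermediate timestep t_mid, might look like this:

# Sketch: two-step consistency sampling (predict x0, re-noise to an intermediate
# timestep, then predict x0 again). The choice of t_mid here is purely illustrative.
with torch.no_grad():
    z = torch.randn(num_samples, *data_shape).to(device)
    t_max = torch.full((num_samples,), T - 1, dtype=torch.long, device=device)
    eps_pred = consistency_model(z, t_max)
    x0_est = get_x0_from_epsilon(z, eps_pred, t_max, alphas_cumprod)

    # Re-noise the estimate to an intermediate timestep and refine it
    t_mid = torch.full((num_samples,), T // 2, dtype=torch.long, device=device)
    alpha_mid = alphas_cumprod[t_mid].view(-1, 1, 1, 1)
    x_mid = torch.sqrt(alpha_mid) * x0_est + torch.sqrt(1.0 - alpha_mid) * torch.randn_like(x0_est)
    eps_pred = consistency_model(x_mid, t_mid)
    x0_refined = get_x0_from_epsilon(x_mid, eps_pred, t_mid, alphas_cumprod)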
The ema_decay hyperparameter controls how quickly the target network tracks the student: values close to 1 (e.g., 0.99, 0.999) make the target change slowly, which stabilizes training.

This practical exercise provides a foundation for understanding how consistency distillation works. By implementing this basic version, you gain insight into the mechanics of training models for rapid, few-step generation, a significant step toward making diffusion models practical for real-time applications.