Leveraging the power of pre-trained diffusion models provides an effective route for training consistency models. This method, known as consistency distillation (CD), treats the existing diffusion model as a "teacher" that guides the training of the "student" consistency model. The goal is to transfer the generative capabilities learned by the iterative teacher model into a student model capable of fast, potentially single-step, generation.
The Teacher-Student Framework
In this setup:
- Teacher Model (ϕ): This is a pre-trained, high-performing diffusion model (like a DDPM or DDIM-trained model). Its role is to provide accurate estimates of the solution paths defined by the probability flow ODE associated with the diffusion process. It doesn't generate the final output directly but provides the necessary intermediate steps or score estimates. The teacher model's parameters (ϕ) are frozen during consistency distillation.
- Student Model (θ): This is the consistency model fθ(x,t) we aim to train. It takes a noisy input xt and a timestep t and directly predicts the estimated origin of the trajectory, $\hat{x}_0$.
- Target Model (θ−): To stabilize training and improve performance, a separate target network fθ−(x,t) is typically used. This network's parameters (θ−) are an exponential moving average (EMA) of the student model's parameters (θ). It provides the target values for the student model's predictions during training.
The core idea is to train the student model fθ such that its output remains consistent along the trajectories defined by the teacher model ϕ.
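A minimal sketch of how these three networks might be set up in PyTorch follows; the `TeacherUNet` class and the checkpoint path are placeholders for whichever pre-trained architecture is being distilled, not a specific library API.

```python
import copy
import torch

# Illustrative setup; `TeacherUNet` and the checkpoint path are placeholders
# for whatever architecture/weights the pre-trained diffusion model uses.
teacher = TeacherUNet()
teacher.load_state_dict(torch.load("teacher_checkpoint.pt"))
teacher.requires_grad_(False)
teacher.eval()                     # frozen teacher, phi

# The student typically reuses the teacher's architecture and is often
# initialized from its weights; it is the only network receiving gradients.
student = copy.deepcopy(teacher)   # f_theta
student.requires_grad_(True)
student.train()

# The target network f_{theta^-} starts as a copy of the student and is
# updated only via an EMA of the student's parameters, never by backprop.
target = copy.deepcopy(student)
target.requires_grad_(False)
target.eval()
```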
The Consistency Distillation Objective
Recall the consistency property: for any pair of points (xt,xt′) on the same ODE trajectory where t′<t, we want f(xt,t)≈f(xt′,t′). Distillation enforces this by minimizing the difference between the student model's output at a later time t and the target model's output at an earlier time t′ on the same trajectory, where the step from xt to xt′ is estimated using the teacher model.
The training process involves sampling pairs of adjacent timesteps $(t_{n+1}, t_n)$ from a discretization $0 < \epsilon = t_1 < t_2 < \cdots < t_N = T$ of the time interval $[\epsilon, T]$. For each pair (the full procedure is sketched in code after this list):
- Sample a data point x0∼pdata(x).
- Sample Gaussian noise z∼N(0,I).
- Generate the noisy sample $x_{t_{n+1}}$ corresponding to time $t_{n+1}$ using the standard forward process, e.g. $x_{t_{n+1}} = \alpha_{t_{n+1}} x_0 + \sigma_{t_{n+1}} z$.
- Use the teacher model ϕ and a one-step ODE solver (like Euler or Heun) to estimate the point $x_{t_n}$ on the trajectory that would precede $x_{t_{n+1}}$. This step typically involves using the teacher's noise prediction $\hat{\epsilon}_\phi(x_{t_{n+1}}, t_{n+1})$ or score estimate $s_\phi(x_{t_{n+1}}, t_{n+1})$. For instance, using the DDIM update rule:
$$\hat{x}_0 = \frac{x_{t_{n+1}} - \sigma_{t_{n+1}}\,\hat{\epsilon}_\phi(x_{t_{n+1}}, t_{n+1})}{\alpha_{t_{n+1}}}$$
$$x_{t_n} = \alpha_{t_n}\,\hat{x}_0 + \sigma_{t_n}\,\hat{\epsilon}_\phi(x_{t_{n+1}}, t_{n+1})$$
(Note: More sophisticated ODE solvers can be used here for better accuracy).
- Compute the consistency distillation loss:
$$\mathcal{L}_{CD}(\theta, \theta^-; \phi) = \mathbb{E}_{n,\, x_0,\, z}\big[\lambda(t_n)\, d\big(f_\theta(x_{t_{n+1}}, t_{n+1}),\; f_{\theta^-}(x_{t_n}, t_n)\big)\big]$$
Here:
- n is sampled uniformly from {1,…,N−1}.
- fθ(xtn+1,tn+1) is the student model's prediction using the "later" noisy sample.
- fθ−(xtn,tn) is the target model's prediction using the "earlier" sample estimated via the teacher. Crucially, gradients are not propagated through the target network fθ− or the teacher model ϕ.
- d(⋅,⋅) is a distance function measuring the difference between the predictions. Common choices include L2 distance, L1 distance, or perceptual metrics like LPIPS.
- λ(tn) is an optional positive weighting function, often set to 1.
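Putting the steps above together, here is a minimal PyTorch-style sketch of one consistency distillation training step. It assumes an epsilon-prediction teacher, a student and target network that map (x_t, t) directly to an $\hat{x}_0$ estimate, precomputed tensors `alphas`, `sigmas`, and `timesteps` for the discretization, an L2 distance, and λ(t_n) = 1; it illustrates the procedure rather than reproducing any particular codebase.

```python
import torch
import torch.nn.functional as F

def cd_training_step(student, target, teacher, x0, alphas, sigmas, timesteps):
    """One consistency-distillation step (illustrative sketch).
    `alphas`, `sigmas`, `timesteps` are 1-D tensors of length N holding the
    discretization t_1 < ... < t_N; `teacher` predicts noise epsilon, while
    `student` and `target` map (x_t, t) directly to an estimate of x_0.
    Inputs are assumed to be image batches of shape (B, C, H, W)."""
    B = x0.shape[0]
    # Pick a random index n (0-indexed: 0 .. N-2) per sample.
    n = torch.randint(0, len(timesteps) - 1, (B,), device=x0.device)

    a_np1 = alphas[n + 1].view(-1, 1, 1, 1)
    s_np1 = sigmas[n + 1].view(-1, 1, 1, 1)
    a_n = alphas[n].view(-1, 1, 1, 1)
    s_n = sigmas[n].view(-1, 1, 1, 1)

    # Forward process: perturb x0 to time t_{n+1}.
    z = torch.randn_like(x0)
    x_np1 = a_np1 * x0 + s_np1 * z

    with torch.no_grad():
        # Teacher + one DDIM step: estimate x_{t_n} on the same ODE trajectory.
        eps = teacher(x_np1, timesteps[n + 1])
        x0_hat = (x_np1 - s_np1 * eps) / a_np1
        x_n = a_n * x0_hat + s_n * eps

        # Target network prediction at the earlier point (no gradient flow).
        tgt = target(x_n, timesteps[n])

    # Student prediction at the later point; d = L2, lambda(t_n) = 1.
    pred = student(x_np1, timesteps[n + 1])
    return F.mse_loss(pred, tgt)
```

In the original consistency models formulation the student is parameterized as $f_\theta(x, t) = c_{\text{skip}}(t)\,x + c_{\text{out}}(t)\,F_\theta(x, t)$ with $c_{\text{skip}}(\epsilon) = 1$ and $c_{\text{out}}(\epsilon) = 0$, so that the boundary condition $f_\theta(x, \epsilon) = x$ holds by construction; the sketch above assumes this is handled inside the `student` and `target` modules.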
Target Network Updates
The target network parameters θ− are updated periodically using an exponential moving average (EMA) of the student parameters θ:
$$\theta^- \leftarrow \mu\,\theta^- + (1 - \mu)\,\theta$$
The momentum parameter μ is typically close to 1 (e.g., 0.99, 0.999). This slow update provides stable targets for the student model, preventing oscillations and improving convergence, similar to techniques used in reinforcement learning and self-supervised learning.
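A possible implementation of this update in PyTorch, applied after each optimizer step (parameter and function names are illustrative):

```python
import torch

@torch.no_grad()
def ema_update(target, student, mu=0.999):
    # theta^- <- mu * theta^- + (1 - mu) * theta, applied parameter-wise.
    for p_tgt, p_stu in zip(target.parameters(), student.parameters()):
        p_tgt.mul_(mu).add_(p_stu, alpha=1.0 - mu)
```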
Implementation Considerations
- Timestep Discretization (N): The number of discrete steps N used during training sets how finely consistency is enforced along the trajectory. Larger N places adjacent timesteps $(t_{n+1}, t_n)$ closer together, which makes the teacher's one-step ODE estimate more accurate, but each individual consistency pair then provides a weaker training signal (see the discretization sketch after this list).
- ODE Solver: The choice of ODE solver used to estimate xtn from xtn+1 with the teacher model impacts the accuracy of the target. Higher-order solvers might yield better results at the cost of computation.
- Distance Metric (d): L2 loss is common, but L1 can be more robust to outliers. Perceptual losses like LPIPS can sometimes yield results that align better with human perception, especially for images.
- Architecture: The architecture of the student model fθ often mirrors the teacher model's architecture (e.g., a U-Net or DiT) but is trained with the consistency objective.
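To make the discretization point concrete, the sketch below constructs the N boundary timesteps with the Karras-style spacing commonly used for teachers in the EDM/σ parameterization; the defaults ε = 0.002, T = 80, ρ = 7 are assumed values, and a uniform spacing over [ε, T] is a simpler alternative.

```python
import torch

def karras_timesteps(N, eps=0.002, T=80.0, rho=7.0):
    # Boundaries t_1 = eps < ... < t_N = T, spaced so that steps concentrate
    # near the low-noise end of the trajectory (rho controls the skew).
    i = torch.arange(N, dtype=torch.float64)
    return (eps ** (1 / rho) + i / (N - 1) * (T ** (1 / rho) - eps ** (1 / rho))) ** rho

timesteps = karras_timesteps(N=18)  # e.g. N = 18; larger N enforces consistency on a finer grid
```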
Diagram illustrating the consistency distillation training process. Data x0, noise z, and a timestep tn+1 produce xtn+1. The teacher model ϕ helps estimate the prior point xtn on the trajectory. The student model fθ predicts the origin from xtn+1, while the target model fθ− predicts the origin from xtn. The loss minimizes the distance between these predictions, updating only the student model parameters θ. The target parameters θ− are updated via EMA from θ.
Advantages and Disadvantages
Advantages:
- Leverages Powerful Teachers: Effectively transfers the knowledge of state-of-the-art diffusion models, rather than learning the data distribution from scratch.
- Potentially Faster Convergence: Compared to training from scratch (consistency training), distillation can sometimes converge faster as it starts with strong guidance from the teacher.
- High-Quality Results: Distilled consistency models have demonstrated the ability to generate high-fidelity samples in significantly fewer steps than their teacher models.
Disadvantages:
- Dependency on Teacher Model: The performance of the distilled consistency model is inherently limited by the quality of the teacher diffusion model. Any flaws or biases in the teacher may be transferred.
- Requires Pre-trained Model: This approach necessitates having a well-trained diffusion model available, which itself requires significant computational resources and data.
Consistency distillation provides a practical and effective method for obtaining fast generative models by building upon the successes of established diffusion models. It represents a significant step towards mitigating the slow sampling speed that often limits the applicability of diffusion-based generative approaches.