While consistency distillation offers a way to accelerate existing diffusion models, Consistency Training (CT) provides a method to train consistency models from scratch, without requiring a pre-trained diffusion model. This standalone approach directly optimizes a neural network fθ(x,t) to satisfy the consistency property: points along the same Probability Flow (PF) ODE trajectory should map to the same origin x0.
The core challenge lies in enforcing this property without having explicit access to the ODE trajectories defined by a pre-trained model. Standalone consistency training cleverly addresses this by using numerical ODE solvers and the model's own evolving predictions during training.
The Consistency Training Objective
The objective is to learn a function fθ(x,t) such that for any valid time t∈[ϵ,T] and any point xt on an ODE trajectory originating from x0, we have fθ(xt,t)≈x0. Here, ϵ is a small minimum time step close to 0, and T is the maximum diffusion time.
To achieve this, CT minimizes a loss that encourages consistency between the model's outputs at adjacent time steps along estimated ODE trajectories. Consider two consecutive time steps tn and tn+1 from a discretized schedule t1, t2, ..., tN, where t1 ≈ ϵ and tN = T. Let xtn and xtn+1 be points on the same (estimated) ODE trajectory. The consistency loss enforces that the model's predictions for x0 from these two points are similar:
LCT(θ, θ−) = E[ λ(tn+1) · d( fθ(xtn+1, tn+1), fθ−(xtn, tn) ) ]
The ingredients of this objective are as follows:
Sampling: We sample a real data point x0, a random time step index n, and a noise vector z.
Trajectory Points: We need to obtain xtn and xtn+1. These points are estimated by taking one step of a numerical ODE solver (such as first-order Euler or second-order Heun) starting from a point perturbed from x0. For example, using the Euler method for the VP Probability Flow ODE dx/dt = −(1/2) β(t) (∇x log pt(x) + x):
xtn ≈ x0 + √( α(tn)² / (1 − α(tn)²) ) · z (approximation based on diffusion process properties)
xtn+1 is obtained by taking one step from xtn using the ODE solver. This requires an estimate of the score ∇x log pt(x) at (xtn, tn). Crucially, since no teacher model is available, CT uses the score implied by the current model's own estimate of x0 (or a closely related approximation), effectively bootstrapping the process.
Consistency Function fθ: This is the neural network being trained. It takes a noisy input xt and time t and predicts the corresponding x0.
Target Network fθ−: A slowly updated, exponential moving average (EMA) of the main network's parameters (θ) is used for the "target" prediction fθ−(xtn,tn). This stabilizes training, similar to techniques used in reinforcement learning or BYOL. The update rule is typically θ−←μθ−+(1−μ)θ, where μ is a momentum coefficient close to 1 (e.g., 0.999).
Distance Metric d(⋅,⋅): This measures the difference between the two predictions. Common choices include the L1 loss, L2 loss (Mean Squared Error), or perceptual losses like LPIPS.
Weighting Function λ(tn+1): This function weights the loss based on the time step. It often prioritizes matching at earlier time steps (closer to the data) or uses weighting schemes derived from diffusion model theory.
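Two of these components, the EMA target update and the weighted distance, are mechanical enough to sketch directly. A minimal NumPy sketch, with illustrative function names and a placeholder uniform λ (real weighting schedules vary):

```python
import numpy as np

def ema_update(theta_target, theta_online, mu=0.999):
    # θ− ← μ·θ− + (1 − μ)·θ, applied parameter-by-parameter.
    return [mu * pt + (1.0 - mu) * po
            for pt, po in zip(theta_target, theta_online)]

def weighted_consistency_loss(pred_online, pred_target, lam):
    # λ(tn+1) · d(·,·) with a squared-L2 distance, averaged over the batch.
    diff = pred_online - pred_target
    d = np.mean(diff ** 2, axis=tuple(range(1, diff.ndim)))
    return float(np.mean(lam * d))
```

Swapping the squared-L2 line for an absolute difference gives the L1 variant; a perceptual metric like LPIPS would replace the whole distance computation with a network evaluation.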
Estimating Trajectories and Scores
The most significant difference from distillation is how xtn and xtn+1 (points on the same trajectory) are obtained. Since there's no teacher model providing the score ∇xlogpt(x), CT must estimate it.
One common approach involves using the relationship between the score function and the conditional expectation E[x0∣xt]. If the model fθ(xt,t) estimates x0, it can be used to approximate the score needed for the ODE solver step that generates xtn+1 from xtn. This creates a self-supervised loop where the model refines its understanding of the trajectories and the consistency mapping simultaneously.
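As a concrete sketch of this bootstrapping, assume the simple perturbation xt = x0 + t·z (an EDM-style parameterization, not the VP form above). Under that assumption the score has the closed form ∇x log pt(x) = (E[x0∣xt] − xt)/t², so plugging fθ(xt, t) in for the conditional expectation yields a usable score estimate:

```python
import numpy as np

def estimated_score(x_t, t, denoiser):
    # ∇_x log p_t(x) ≈ (fθ(x_t, t) − x_t) / t², substituting the model's
    # x0 estimate for the conditional expectation E[x0 | x_t].
    return (denoiser(x_t, t) - x_t) / (t ** 2)

def euler_step(x_t, t, t_next, denoiser):
    # One Euler step of the PF ODE dx/dt = −t · ∇_x log p_t(x)
    # (its form under the x_t = x0 + t·z parameterization).
    drift = -t * estimated_score(x_t, t, denoiser)
    return x_t + (t_next - t) * drift
```

With a perfect denoiser the drift reduces to z, so one step from t to t_next lands exactly on x0 + t_next·z; with an imperfect denoiser, the trajectory estimate improves as training progresses.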
Training Algorithm Overview
The standalone consistency training process typically follows these steps in each iteration:
Sample Data: Draw a mini-batch of data points {x0(i)} from the true data distribution pdata.
Sample Timesteps: For each x0(i), sample a time index n(i)∼U(1,N−1). Let tn=schedule[n] and tn+1=schedule[n+1].
Generate Trajectory Pair:
Generate noise z(i)∼N(0,I).
Estimate xtn(i) (e.g., using x0(i) and z(i) based on the diffusion process definition).
Estimate the score at (xtn(i),tn), potentially using fθ or fθ−.
Use a numerical ODE solver (e.g., one step of Heun's method) with the estimated score to compute xtn+1(i) from xtn(i).
Network Predictions: Compute the online prediction yn+1(i) = fθ(xtn+1(i), tn+1) and the target prediction yn(i) = fθ−(xtn(i), tn) (stop gradient through the target network).
Calculate Loss: Compute the consistency loss LCT using the distance metric d and weighting λ(tn+1):
LCT = (1/B) ∑i=1…B λ(tn+1(i)) · d( yn+1(i), yn(i) )
where B is the batch size.
Gradient Update: Compute the gradient ∇θLCT and update the online network parameters θ using an optimizer (e.g., Adam).
Update Target Network: Update the target network parameters θ− using EMA: θ−←μθ−+(1−μ)θ.
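The steps above can be put together in a toy end-to-end iteration. Everything in this sketch is illustrative: a linear map stands in for the network, the schedule and the xtn = x0 + tn·z perturbation follow the EDM-style assumption rather than the VP form, and the gradient is written out by hand for the linear toy instead of using autodiff:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (all hypothetical): data lives in R^2 and a linear map
# stands in for the network fθ(x, t).
N = 10
schedule = np.linspace(0.002, 1.0, N)       # t_1 ≈ ε, ..., t_N = T
theta = rng.normal(size=(2, 2)) * 0.1       # online parameters θ
theta_ema = theta.copy()                    # target parameters θ−
mu, lr = 0.999, 1e-2

def f(x, t, W):
    # Toy "network": predicts x0 from x_t (ignores t for simplicity).
    return x @ W

def train_step(theta, theta_ema, x0):
    B = x0.shape[0]
    n = rng.integers(0, N - 1, size=B)      # sample a time index per example
    t_n, t_np1 = schedule[n, None], schedule[n + 1, None]
    z = rng.normal(size=x0.shape)
    x_tn = x0 + t_n * z                     # point on the (toy) trajectory
    # One Euler step of the PF ODE dx/dt = (x − x0_hat)/t, with the score
    # bootstrapped from the target network's own x0 estimate.
    x0_hat = f(x_tn, t_n, theta_ema)
    x_tnp1 = x_tn + (t_np1 - t_n) * (x_tn - x0_hat) / t_n
    y_target = x0_hat                       # fθ−(x_tn, t_n), no gradient
    y_online = f(x_tnp1, t_np1, theta)      # fθ(x_tn+1, t_n+1)
    resid = y_online - y_target
    loss = np.mean(np.sum(resid ** 2, axis=1))
    grad = x_tnp1.T @ (2 * resid) / B       # hand-derived ∂loss/∂θ (linear toy)
    theta = theta - lr * grad               # gradient step (plain SGD)
    theta_ema = mu * theta_ema + (1 - mu) * theta   # EMA target update
    return theta, theta_ema, loss
```

A real implementation would replace the linear map with a U-Net or Transformer, use an optimizer such as Adam with autodiff gradients, and apply a non-uniform weighting λ(tn+1); the control flow of each iteration is the same.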
Diagram illustrating the standalone consistency training loop for a single data point. The process involves sampling data and time, generating a pair of points along an estimated ODE trajectory using score estimates, computing outputs from the online and target networks, and updating the networks based on the consistency loss.
Architectural Considerations
The network architecture fθ(x,t) used in standalone CT is often similar to those employed in standard diffusion models or consistency distillation, such as U-Net variants or Transformers (like DiT). The key requirement is the ability to process a noisy input xt and a time embedding t to produce an estimate of x0. Adaptations might involve modifications to how time is embedded or how conditioning information (if any) is integrated.
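One concrete ingredient shared by these architectures is the time embedding: the scalar t is mapped to a vector of sinusoidal features before being injected into the network, as popularized by Transformer positional encodings. A minimal sketch, where the dimension and max_period are arbitrary illustrative choices:

```python
import numpy as np

def timestep_embedding(t, dim=16, max_period=10000.0):
    # Sinusoidal features of the scalar time t: half cosines, half sines,
    # at geometrically spaced frequencies. Returns shape (len(t), dim).
    t = np.asarray(t, dtype=np.float64).reshape(-1, 1)
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    angles = t * freqs[None, :]
    return np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
```

The resulting vector is typically passed through a small MLP and added (or fed via modulation layers) to intermediate feature maps of the U-Net or DiT blocks.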
Comparison to Distillation
Pros of Standalone Training:
Independence: Does not require a pre-trained diffusion model, saving the computational cost and time associated with training one first.
Potential for Higher Fidelity: Might avoid limitations inherent in the distillation process where the student model tries to mimic a potentially imperfect teacher.
Cons of Standalone Training:
Stability: Can be harder to stabilize than distillation, as it relies on bootstrapping the score estimation. Requires careful tuning of hyperparameters like the learning rate, EMA decay rate (μ), and loss weighting (λ(t)).
Convergence Speed: May require more training iterations to converge compared to distillation, which starts with guidance from a strong teacher model.
Standalone consistency training represents a significant step towards efficient generative modeling, enabling the creation of fast, few-step or even single-step generative models directly from data. While it presents unique training challenges compared to distillation, its independence from pre-trained models makes it an appealing and powerful technique in the generative modeling toolkit.