While fixed noise schedules, whether linear, cosine, or custom-designed, provide a predefined path for the diffusion process, they operate under the assumption that a single, predetermined variance trajectory for the reverse steps is optimal. However, the ideal variance $\sigma_t^2$ in the reverse process step $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I)$ might vary depending on the timestep $t$ or even the specific data instance. Fixing the variance, typically to choices derived from the forward process noise schedule such as $\beta_t$ or $\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t$, limits the model's flexibility.
This limitation motivates the development of learned variance schedules, a technique in which the diffusion model itself learns to predict the appropriate variance for each reverse step. This approach, notably explored by Nichol and Dhariwal in "Improved Denoising Diffusion Probabilistic Models" (2021), allows the model to dynamically adjust the stochasticity of the reverse process.
Why Learn the Variance?
Learning the variance grants the model more expressive power. Consider these points:
- Optimal Noise Levels: Different stages of the reverse process might benefit from different amounts of noise. Early steps (large $t$) might need larger variance to make significant changes from pure noise, while later steps (small $t$) might need smaller variance for fine detail refinement. A learned variance allows the model to tailor this.
- Improved Likelihood: The original DDPM paper fixed the reverse process variance primarily for simplicity, optimizing a surrogate objective related to noise prediction ($L_{\text{simple}}$). Learning the variance allows for direct optimization related to the variational lower bound (VLB) on the data log-likelihood, potentially leading to models that better capture the true data distribution and achieve superior log-likelihood scores.
- Data-Dependent Adaptation: While typically learned as a function of timestep $t$, the mechanism could potentially adapt based on $x_t$ as well, although this is less common.
Parameterizing and Predicting the Variance
Instead of fixing $\sigma_t^2$, we parameterize it and have the model predict the parameters. A common approach involves interpolating between the two standard fixed choices, $\beta_t$ and $\tilde{\beta}_t$. Recall that $\beta_t$ corresponds to the forward process variance at step $t$, and $\tilde{\beta}_t$ is derived to match the variance of the posterior $q(x_{t-1} \mid x_t, x_0)$ when $x_0$ is known.
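As a concrete reference point, both fixed choices can be computed directly from the forward schedule. The sketch below uses NumPy and a linear schedule; the constants (1000 steps, the $10^{-4}$ to $0.02$ range) are the common DDPM defaults, not requirements.

```python
import numpy as np

# Linear forward schedule (DDPM defaults; any schedule works here)
T = 1000
betas = np.linspace(1e-4, 0.02, T)

alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)                   # cumulative product \bar{alpha}_t
alpha_bar_prev = np.append(1.0, alpha_bar[:-1])  # \bar{alpha}_{t-1}, with \bar{alpha}_0 = 1

# Posterior variance: beta_tilde_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
beta_tilde = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar) * betas

# Note: beta_tilde at t = 1 is exactly zero, so implementations typically
# clip it (e.g., replace it with the t = 2 value) before taking logarithms.
```

Note that $\tilde{\beta}_t < \beta_t$ for all $t > 1$, so the two choices bracket a range of plausible variances.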
The learned variance $\sigma_{\theta,t}^2$ can be parameterized as:

$$\sigma_{\theta,t}^2 = \exp\left(v \log \beta_t + (1 - v) \log \tilde{\beta}_t\right)$$
Here, $v$ is a parameter predicted by the neural network, constrained between 0 and 1. The network architecture (e.g., U-Net or Transformer) is modified to output additional channels representing $v$, usually one component per dimension (matching the shape of the noise prediction) so the variance can vary spatially, alongside the prediction used for the mean $\mu_\theta(x_t, t)$ (which is typically derived from the noise prediction $\epsilon_\theta(x_t, t)$).
The diffusion model takes the noisy input $x_t$ and the timestep $t$ embedding. Its output is split to predict both the noise $\epsilon_\theta$ (determining the reverse process mean) and the variance parameter $v_\theta$.
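A minimal sketch of that split, assuming the model's channel axis has been doubled; the function name, shapes, and sigmoid squashing are illustrative choices, not from a specific codebase:

```python
import numpy as np

def split_and_compute_variance(model_out, betas, beta_tilde, t):
    """Split the doubled model output into the noise prediction and the
    interpolation coefficient v, then form sigma^2_{theta,t}.

    model_out: array of shape (2C, H, W); betas, beta_tilde: per-step schedules.
    Assumes beta_tilde[t] > 0 (i.e., t > 0 or a clipped schedule).
    """
    eps, raw_v = np.split(model_out, 2, axis=0)   # (C, H, W) each
    v = 1.0 / (1.0 + np.exp(-raw_v))              # squash raw output to (0, 1)
    log_var = v * np.log(betas[t]) + (1.0 - v) * np.log(beta_tilde[t])
    return eps, np.exp(log_var)
```

Because $v$ lies in $(0, 1)$, the predicted variance is guaranteed to stay between $\tilde{\beta}_t$ and $\beta_t$, which keeps the interpolation numerically well-behaved.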
Adjusting the Training Objective
When learning the variance, the training objective needs to account for this prediction. The original DDPM $L_{\text{simple}}$ objective only focuses on predicting the noise $\epsilon$. To train the variance prediction $v$, the loss function incorporates a term derived from the VLB, often denoted $L_{\text{vlb}}$. This term directly involves the predicted variance $\sigma_{\theta,t}^2$.
A common practice is to use a hybrid objective:
$$L_{\text{hybrid}} = L_{\text{simple}} + \lambda L_{\text{vlb}}$$
where $L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right]$ is the standard noise prediction loss, and $L_{\text{vlb}}$ is the term encouraging accurate variance prediction. The hyperparameter $\lambda$ balances the two objectives. Setting $\lambda = 0$ recovers the standard DDPM training with fixed variance. Nichol and Dhariwal found that a small, non-zero $\lambda$ (e.g., $\lambda = 0.001$) worked well, and they additionally applied a stop-gradient to the mean inside $L_{\text{vlb}}$ so that this term trains only the variance prediction, preserving the sample quality benefits of $L_{\text{simple}}$ while gaining the likelihood improvements from $L_{\text{vlb}}$.
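The hybrid objective can be sketched as below. The helper names (`normal_kl`, `hybrid_loss`) are illustrative; the KL term is the standard closed form between diagonal Gaussians, which is the core of $L_{\text{vlb}}$ at intermediate timesteps. Plain NumPy carries no gradients, so the stop-gradient on the mean is noted only as a comment.

```python
import numpy as np

def normal_kl(mean1, logvar1, mean2, logvar2):
    """Elementwise KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) )."""
    return 0.5 * (logvar2 - logvar1
                  + np.exp(logvar1 - logvar2)
                  + (mean1 - mean2) ** 2 * np.exp(-logvar2)
                  - 1.0)

def hybrid_loss(eps, eps_pred, q_mean, q_logvar, p_mean, p_logvar, lam=0.001):
    """L_hybrid = L_simple + lambda * L_vlb (a sketch).

    q_*: parameters of the true posterior q(x_{t-1} | x_t, x_0);
    p_*: parameters of the model's reverse step, with p_logvar the learned
         log-variance. In a real implementation p_mean is detached
         (stop-gradient) here so L_vlb trains only the variance head.
    """
    l_simple = np.mean((eps - eps_pred) ** 2)
    l_vlb = np.mean(normal_kl(q_mean, q_logvar, p_mean, p_logvar))
    return l_simple + lam * l_vlb
```

A full implementation would also handle the $t = 1$ term (a discretized log-likelihood rather than a KL), which is omitted here for brevity.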
Implementation Considerations
Implementing learned variance involves these primary modifications:
- Model Output: Adjust the final layer of your network (U-Net or Transformer) to output twice the number of channels compared to standard $\epsilon$-prediction. One half represents $\epsilon_\theta$, and the other half represents the parameters needed to compute $\sigma_{\theta,t}^2$ (e.g., the value $v$ used to interpolate between $\beta_t$ and $\tilde{\beta}_t$).
- Loss Function: Implement the hybrid loss $L_{\text{hybrid}}$, calculating both the MSE loss on $\epsilon_\theta$ and the $L_{\text{vlb}}$ term based on the predicted variance.
- Sampling: During sampling, calculate the reverse step $x_{t-1}$ using the standard mean calculation derived from the predicted $\epsilon_\theta$, but draw the noise component $z \sim \mathcal{N}(0, I)$ scaled by the predicted standard deviation $\sigma_{\theta,t}$ instead of a fixed $\sigma_t$.
$$x_{t-1} = \mu_\theta(x_t, t) + \sigma_{\theta,t} z, \quad \text{where } z \sim \mathcal{N}(0, I) \text{ if } t > 1, \text{ else } z = 0$$
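The sampling step above can be sketched in a few lines; the function and argument names are illustrative:

```python
import numpy as np

def reverse_step(mean, sigma, t, rng):
    """One reverse step: x_{t-1} = mu_theta + sigma_theta * z, with z = 0 at t = 1.

    mean:  predicted reverse-process mean mu_theta(x_t, t)
    sigma: learned standard deviation sigma_{theta,t} (scalar or per-pixel array)
    rng:   a numpy Generator supplying the Gaussian noise z
    """
    z = rng.standard_normal(mean.shape) if t > 1 else np.zeros_like(mean)
    return mean + sigma * z
```

The only change from fixed-variance sampling is where `sigma` comes from: it is computed from the network's predicted $v$ rather than read off a precomputed schedule.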
Benefits and Trade-offs
- Benefits:
- Can significantly improve log-likelihood scores compared to fixed variance models.
- May lead to modest improvements in sample quality (e.g., FID scores), although the primary gain is often in likelihood.
- Provides the model with greater flexibility to adapt the generation process.
- Trade-offs:
- Increases model complexity as the network must predict additional parameters.
- The hybrid loss function adds complexity to the training setup.
- Requires careful tuning of the hyperparameter λ in the hybrid loss.
Learned variance schedules represent a step beyond fixed or manually designed schedules, empowering the diffusion model to optimize a fundamental aspect of the reverse process. While adding some complexity, the potential gains in likelihood and adaptability make it a valuable technique in the advanced diffusion modeling toolkit, particularly when accurate density estimation is as important as sample quality.