Having established the theoretical underpinnings of Classifier-Free Guidance (CFG), we now turn to its practical implementation and the important process of tuning the guidance scale, often denoted as s or w. This scale dictates how strongly the generation process adheres to the provided conditioning signal (like a text prompt or class label).
CFG works by cleverly modifying the model's predicted noise (or predicted x0, depending on the parameterization) during the sampling process. At each timestep t, the model makes two predictions: a conditional prediction ϵθ(xt, t, c) that uses the provided conditioning c, and an unconditional prediction ϵθ(xt, t, ∅) that uses a learned null condition ∅.
During training, this unconditional prediction is made possible by randomly dropping the conditioning information for a fraction of the training examples (e.g., replacing text embeddings with a learned null token, or class labels with a special 'unconditional' ID). This trains a single model to predict noise both with and without conditioning.
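As a minimal sketch of this training-time conditioning dropout (the names p_uncond, text_embeddings, and null_embedding are illustrative placeholders, not part of any specific library):

import torch

def apply_conditioning_dropout(text_embeddings, null_embedding, p_uncond=0.1):
    """Randomly replace conditioning with the learned null embedding.

    text_embeddings: (batch, seq_len, dim) conditioning for each example
    null_embedding:  (seq_len, dim) learned "empty prompt" embedding
    p_uncond:        fraction of examples trained unconditionally
    """
    batch_size = text_embeddings.shape[0]
    # True where this example's conditioning should be dropped
    drop_mask = torch.rand(batch_size, device=text_embeddings.device) < p_uncond
    # Broadcast the null embedding over the dropped examples
    null = null_embedding.unsqueeze(0).expand_as(text_embeddings)
    return torch.where(drop_mask[:, None, None], null, text_embeddings)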
At inference time, both predictions are computed. The final noise prediction used for the denoising step is a linear combination, extrapolated away from the unconditional prediction in the direction of the conditional one:
$$\hat{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \varnothing) + s \cdot \big(\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \varnothing)\big)$$

Here, s is the guidance scale. Notice that if s = 0, we recover the unconditional prediction, ϵ̂θ = ϵθ(xt, t, ∅). If s = 1, we recover the standard conditional prediction, ϵ̂θ = ϵθ(xt, t, c). Values of s > 1 push the prediction further in the direction indicated by the condition, effectively amplifying the guidance.
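To make the extrapolation concrete with illustrative numbers: if, for a given noise component, the unconditional prediction is 0.2 and the conditional prediction is 0.5, then s = 7.5 yields 0.2 + 7.5 · (0.5 − 0.2) = 2.45, well beyond the conditional prediction itself. Guidance with s > 1 extrapolates past the conditional prediction rather than interpolating between the two.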
This combined prediction ϵ^θ is then used within the chosen sampler (e.g., DDIM, DPM-Solver) to estimate xt−1 from xt.
The guidance scale s is a hyperparameter that controls the trade-off between adherence to the condition (which often reads as higher sample quality) and sample diversity.
The impact of s is highly dependent on the specific model, dataset, and task. There's no single "best" value; it requires tuning.
Finding an effective guidance scale typically involves experimentation: sweep s over a range of values, inspect the resulting samples visually, and track how prompt adherence, diversity, and artifact rates change.
Relationship between guidance scale (s) and hypothetical metrics for prompt adherence, sample diversity, and the likelihood of artifacts. Increasing s generally improves adherence but reduces diversity and may introduce artifacts.
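One simple way to run such an experiment is to sweep s over a handful of values while holding the prompt and random seed fixed, then compare the outputs. In the sketch below, generate is a hypothetical helper that wraps the CFG sampling loop shown later in this section:

import torch

candidate_scales = [1.5, 3.0, 5.0, 7.5, 10.0, 15.0]
samples_by_scale = {}

for s in candidate_scales:
    # Fix the seed so differences come from the guidance scale, not the noise draw
    torch.manual_seed(0)
    # generate() is a hypothetical wrapper around the sampling loop below
    samples_by_scale[s] = generate(prompt="a photo of a red bicycle", guidance_scale=s)

# Compare samples_by_scale[s] side by side, or score each output against the
# prompt (e.g., CLIP similarity) to quantify adherence versus diversity.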
Below is a simplified illustration of how the CFG logic modifies a sampling loop. Assume model predicts the noise ϵθ, sampler handles the denoising step, latents is xt, t is the timestep, cond_embedding is the conditioning, and uncond_embedding is the null conditioning.
import torch
# Assume model, sampler, latents, t, cond_embedding, uncond_embedding are defined
guidance_scale = 7.5  # the CFG scale s; a typical starting point for text-to-image models
# Concatenate inputs for batched inference
latent_model_input = torch.cat([latents] * 2)
time_input = torch.cat([t] * 2)
context_input = torch.cat([uncond_embedding, cond_embedding])
# Predict noise for both unconditional and conditional inputs in one batched forward pass
noise_pred = model(latent_model_input, time_input, context=context_input)
# The first half of the batch holds the unconditional predictions, the second half the conditional ones
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
# Combine predictions using the CFG formula
noise_pred_cfg = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# Use the combined prediction in the sampler's step function
latents = sampler.step(noise_pred_cfg, t, latents)
# ... rest of the sampling loop ...
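As a practical note, concatenating the unconditional and conditional inputs as above computes both predictions in a single batched forward pass, which is usually faster on GPU than two separate calls but roughly doubles activation memory; when memory is tight, the two predictions can instead be computed in separate passes with identical results.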
The optimal s value might change slightly depending on the sampler used (e.g., DDIM vs. DPM-Solver++) and the number of sampling steps. Faster samplers using fewer steps might sometimes benefit from slightly higher guidance scales to maintain prompt fidelity.

Mastering the implementation and tuning of the CFG scale is a fundamental skill for controlling modern diffusion models. It provides a powerful knob for balancing fidelity to the desired output against the inherent creativity and diversity of the generative process. Experimentation, guided by both visual feedback and relevant metrics, is essential for finding the sweet spot for your specific application.