Now that we understand the principle behind Classifier-Free Guidance (CFG), let's put it into practice. This section provides practical examples and code structure for implementing CFG during the sampling process of your diffusion model. We'll assume you have a pre-trained U-Net model capable of accepting conditioning information, as discussed in the "Implementing Classifier-Free Guidance" section.
Remember, CFG guides the generation towards a condition y (like a class label or text embedding) without needing a separate classifier model. It achieves this by leveraging the diffusion model's ability to make both conditional and unconditional predictions. During sampling at each timestep t, we calculate two noise estimates: the conditional prediction ϵ_θ(x_t, t, y) and the unconditional prediction ϵ_θ(x_t, t, ∅).
These two predictions are then combined using a guidance scale w:
$$
\hat{\epsilon}_t = \epsilon_\theta(x_t, t, \varnothing) + w \cdot \left( \epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t, \varnothing) \right)
$$

This guided noise estimate ϵ̂_t is used to compute the next, less noisy state x_{t-1}. The guidance scale w controls how strongly the generation adheres to the condition y. A value of w = 0 ignores the condition, resulting in unconditional sampling, while increasing w pushes the generation more strongly towards the condition y.
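As a minimal sketch, the combination step can be written as a small helper function (the names combine_cfg, eps_uncond, and eps_cond are illustrative, not from any particular library):

import torch

def combine_cfg(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, w: float) -> torch.Tensor:
    # w = 0 returns eps_uncond (purely unconditional); w = 1 returns eps_cond.
    # Values above 1 extrapolate past the conditional prediction.
    return eps_uncond + w * (eps_cond - eps_uncond)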
Before starting the sampling loop, you need to prepare your conditioning input y and the null condition ∅. Let's assume y_cond holds the conditioning vector for your desired output (e.g., the embedding for "cat" or class 7) and y_null holds the vector for the null condition.
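How y_cond and y_null are constructed depends on how the model was conditioned. A common pattern for class-conditional models is an embedding table with one extra row reserved for the null token; the sketch below assumes that setup (num_classes, embed_dim, label_emb, and the use of index num_classes as the null slot are illustrative choices, not requirements):

import torch
import torch.nn as nn

num_classes = 10   # e.g., MNIST digits 0-9
embed_dim = 128
batch_size = 4

# Embedding table with one extra row reserved for the null (unconditional) token.
label_emb = nn.Embedding(num_classes + 1, embed_dim)

desired_class = 7
y_cond = label_emb(torch.full((batch_size,), desired_class, dtype=torch.long))
y_null = label_emb(torch.full((batch_size,), num_classes, dtype=torch.long))

For text-conditioned models, y_null is instead typically the embedding of an empty prompt produced by the same text encoder used for y_cond.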
The core modification happens inside the sampling loop (whether using DDPM or DDIM). Here's a simplified, Python-like pseudo-code structure, assuming a PyTorch-style framework and a scheduler.step helper that performs the standard reverse step (like Eq. 11 or 12 from the DDPM paper, or the DDIM update) given the predicted noise ϵ, the timestep t, and the current state x_t:
import torch

# Assume model is your conditional U-Net and scheduler encapsulates the noise
# schedule and the reverse step. y_cond and y_null are prepared as above.
batch_size = y_cond.shape[0]
sample_shape = (1, 28, 28)            # e.g., single-channel 28x28 images
device = y_cond.device
timesteps = list(range(999, -1, -1))  # e.g., [999, 998, ..., 0]
w = 7.5                               # guidance scale

# Start with pure random noise: x_T ~ N(0, I)
x_t = torch.randn(batch_size, *sample_shape, device=device)

for t_val in timesteps:
    t_tensor = torch.tensor([t_val] * batch_size, device=device)

    # No gradients are needed during inference; torch.no_grad() saves memory and compute.
    with torch.no_grad():
        # 1. Predict noise for the conditional input
        pred_noise_cond = model(x_t, t_tensor, y_cond)

        # 2. Predict noise for the unconditional input
        pred_noise_uncond = model(x_t, t_tensor, y_null)

        # 3. Combine predictions using the CFG formula
        guided_noise = pred_noise_uncond + w * (pred_noise_cond - pred_noise_uncond)

        # 4. Use the guided noise to compute x_{t-1}.
        #    This step depends on whether you use DDPM or DDIM sampling logic;
        #    here scheduler.step encapsulates the chosen reverse step.
        x_t = scheduler.step(guided_noise, t_val, x_t)  # Updates x_t to x_{t-1}

# The final result after the loop is x_0 (the generated sample)
generated_sample = x_t
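In practice, many implementations avoid two separate forward passes by concatenating the conditional and unconditional inputs into a single batch and splitting the output. The sketch below wraps this in a reusable function; the name sample_cfg and the model/scheduler interfaces mirror the assumptions made in the loop above and are illustrative rather than a specific library's API:

import torch

def sample_cfg(model, scheduler, y_cond, y_null, w, sample_shape, timesteps):
    # Illustrative CFG sampling loop using one batched forward pass per step.
    batch_size = y_cond.shape[0]
    device = y_cond.device
    x_t = torch.randn(batch_size, *sample_shape, device=device)

    for t_val in timesteps:
        t_tensor = torch.tensor([t_val] * batch_size, device=device)
        with torch.no_grad():
            # Stack conditional and unconditional inputs along the batch dimension.
            x_in = torch.cat([x_t, x_t], dim=0)
            t_in = torch.cat([t_tensor, t_tensor], dim=0)
            y_in = torch.cat([y_cond, y_null], dim=0)
            pred_cond, pred_uncond = model(x_in, t_in, y_in).chunk(2, dim=0)

            guided_noise = pred_uncond + w * (pred_cond - pred_uncond)
            x_t = scheduler.step(guided_noise, t_val, x_t)

    return x_t

This trades memory for speed: one forward pass with batch size 2N replaces two passes with batch size N at each timestep.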
Key Steps in the Loop:

1. Get the current timestep t.
2. Pass x_t, t, and y_cond to the model to obtain the conditional noise prediction.
3. Pass x_t, t, and y_null to the model to obtain the unconditional noise prediction.
4. Compute guided_noise from the unconditional prediction, the conditional prediction, and the guidance scale w.
5. Use guided_noise in your chosen sampler's (DDPM or DDIM) reverse diffusion equation to calculate x_{t-1}, and update x_t for the next iteration.

Repeat this for all timesteps from T-1 down to 0. The final x_t will be your generated sample x_0.
The choice of w significantly impacts the output.

- Low w (e.g., 0 or 1): Generation is less constrained by the condition. If w = 0, it's purely unconditional; if w = 1, it follows the learned conditional distribution but might lack strong adherence. Samples may be diverse but less aligned with the prompt y.
- Medium w (e.g., 3 to 10): Often the sweet spot. Balances adherence to the condition y with overall sample quality and diversity. The generated image clearly reflects the condition.
- High w (e.g., 15+): Strong adherence to the condition, but samples might become less diverse, potentially exhibiting saturation or artifacts. The model might over-emphasize features related to the condition.

Experimenting with different values of w is common to find the best trade-off for a specific model and task.
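A simple way to run that experiment is to sweep several values of w while holding everything else fixed. The sketch below reuses the illustrative sample_cfg function defined earlier and reseeds the generator before each run so that only the guidance scale changes:

import torch

guidance_scales = [0.0, 1.0, 3.0, 7.5, 15.0]
samples_by_w = {}

for w in guidance_scales:
    torch.manual_seed(0)  # identical starting noise for every w
    samples_by_w[w] = sample_cfg(
        model, scheduler, y_cond, y_null, w,
        sample_shape=(1, 28, 28),
        timesteps=list(range(999, -1, -1)),
    )

Viewing the resulting samples side by side makes the adherence/diversity trade-off described above easy to inspect.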
Let's visualize how changing w might affect generating, say, the digit '8' using a diffusion model trained on MNIST with CFG.

This illustrative plot shows the typical trade-off. As the guidance scale w increases, adherence to the condition (generating an '8') generally improves (blue line), but sample diversity, and potentially overall quality, might decrease after a certain point (orange line), sometimes leading to artifacts at very high values. The green shaded region indicates a common range where a good balance is often found.
It's important to remember that CFG during sampling relies on the model being specifically trained to handle both conditional and unconditional inputs. This is typically achieved using conditioning dropout during the training phase, where the condition is randomly replaced by the null embedding for a fraction of training examples.
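The exact training code varies, but a minimal sketch of the idea, assuming a drop probability p_uncond (commonly around 0.1 to 0.2), a y_null embedding like the one prepared earlier, and scheduler helpers (num_timesteps, add_noise) for the forward process, might look like this:

import torch
import torch.nn.functional as F

p_uncond = 0.1  # probability of dropping the condition for each training example

def training_step(model, scheduler, x_0, y_cond, y_null):
    batch_size = x_0.shape[0]
    device = x_0.device

    # Randomly replace the condition with the null embedding for some examples.
    drop_mask = torch.rand(batch_size, device=device) < p_uncond
    y = torch.where(drop_mask.unsqueeze(-1), y_null, y_cond)

    # Standard noise-prediction objective at a random timestep
    # (scheduler.num_timesteps and scheduler.add_noise are assumed helpers).
    t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
    noise = torch.randn_like(x_0)
    x_t = scheduler.add_noise(x_0, noise, t)

    pred_noise = model(x_t, t, y)
    return F.mse_loss(pred_noise, noise)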
This forces the model to learn how to predict noise both when guided by a specific condition and when no condition is provided (using the null embedding). Without this training strategy, the model wouldn't know how to interpret the null condition ∅, and the CFG formula wouldn't produce meaningful guidance.
By implementing the guided sampling loop described here, leveraging a model trained with conditioning dropout, you can effectively steer the diffusion process to generate outputs that match your desired conditions. This significantly expands the creative control offered by diffusion models.