Now that we have explored the theory behind classifier-free guidance (CFG) in the previous sections, let's put it into practice. This hands-on exercise demonstrates how to modify a standard diffusion model sampling loop to incorporate CFG, enabling conditional generation without requiring a separate classifier model. We assume you have access to a pre-trained conditional diffusion model (like a U-Net or DiT) and a basic sampling function (e.g., implementing DDIM).
The core idea of CFG, as discussed, is to compute two predictions at each timestep $t$ during the reverse process: a conditional prediction $\epsilon_\theta(x_t, c)$ using the guidance condition $c$ (e.g., a text embedding or class label), and an unconditional prediction $\epsilon_\theta(x_t, \varnothing)$ using a null or empty condition $\varnothing$. These are then combined using a guidance scale $s$ (often denoted $w$ in implementations) to steer the generation:

$$\tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \varnothing) + s \cdot \left(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing)\right)$$

This adjusted noise estimate $\tilde{\epsilon}_\theta(x_t, c)$ is then used in the sampler's update step. A scale $s = 0$ corresponds to unconditional generation, while $s = 1$ recovers the plain conditional prediction (assuming the model was trained with conditioning dropout). Values $s > 1$ amplify the guidance signal.
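The combination step itself is tiny. As a minimal self-contained sketch (the helper name apply_cfg is ours, not from any particular library):

import torch

def apply_cfg(eps_uncond, eps_cond, s):
    # Blend unconditional and conditional noise predictions with guidance scale s.
    return eps_uncond + s * (eps_cond - eps_uncond)

# Limiting cases described above: s = 0 is purely unconditional,
# s = 1 recovers the plain conditional prediction.
eps_u, eps_c = torch.randn(2, 4), torch.randn(2, 4)
assert torch.allclose(apply_cfg(eps_u, eps_c, 0.0), eps_u)
assert torch.allclose(apply_cfg(eps_u, eps_c, 1.0), eps_c)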
Let's assume you have a standard DDIM sampling function that looks something like this (simplified Python-like pseudocode):
def ddim_sample_loop(model, x_T, timesteps, condition, eta=0.0):
    x_t = x_T
    for t_idx, t in enumerate(timesteps):
        # Get current and previous timestep
        time_tensor = torch.tensor([t], device=x_t.device)
        prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1

        # 1. Predict noise using the model
        predicted_noise = model(x_t, time_tensor, condition)  # Original prediction

        # 2. Calculate the x_0 prediction (using DDIM formula components like alpha_t)
        # ... calculate pred_x0 ...

        # 3. Calculate the direction pointing to x_t
        # ... calculate dir_xt ...

        # 4. Calculate noise for stochasticity (if eta > 0)
        # ... calculate sigma_t and random_noise ...

        # 5. Compute the next sample x_{t-1}
        x_prev = torch.sqrt(alpha_t_prev) * pred_x0 + dir_xt + sigma_t * random_noise
        x_t = x_prev
    return x_t
To implement CFG, we need to modify step 1, where the model predicts the noise. Conceptually, the single prediction is replaced by two:

noise_pred_cond = model(x_t, time_tensor, condition)
noise_pred_uncond = model(x_t, time_tensor, null_condition)
# Apply the CFG formula with guidance_scale (the scale s from above)
predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
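What null_condition actually is depends on how the model was trained with conditioning dropout. The sketch below shows two common conventions; text_encoder and num_classes are illustrative placeholders, not fixed APIs:

# Match whatever "dropped" condition your model saw during training.
# Text-conditioned models often use the embedding of an empty prompt:
null_condition = text_encoder([""] * batch_size)
# Class-conditioned models often reserve an extra "null" class index:
# null_condition = torch.full((batch_size,), num_classes, dtype=torch.long)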
This blended predicted_noise then replaces the original prediction in the subsequent steps (2-5) of the sampling loop. Here's how the modified loop might look, assuming model takes x_t, time_tensor, and a condition tensor as input, and guidance_scale (s) is provided:
def ddim_sample_loop_cfg(model, x_T, timesteps, condition, null_condition, guidance_scale, eta=0.0):
    x_t = x_T
    batch_size = x_t.shape[0]  # Assuming x_T has shape [B, C, H, W]
    for t_idx, t in enumerate(timesteps):
        time_tensor = torch.full((batch_size,), t, device=x_t.device, dtype=torch.long)
        prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1

        # 1. Predict noise using the model with CFG.
        # Run the conditional and unconditional predictions in a single
        # batched forward pass; requires the model to handle batched conditions.
        model_input = torch.cat([x_t] * 2)              # Duplicate input for cond/uncond
        time_input = torch.cat([time_tensor] * 2)
        condition_input = torch.cat([condition, null_condition])  # Batch conditions
        noise_pred_combined = model(model_input, time_input, condition_input)

        # Split the predictions (order matches the concatenation above)
        noise_pred_cond, noise_pred_uncond = noise_pred_combined.chunk(2)

        # Apply the CFG formula
        predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # --- Remainder of DDIM step (Steps 2-5) ---
        # alpha here denotes the cumulative product from the noise schedule,
        # precomputed or calculated on the fly.
        alpha_t = get_alpha(t)
        alpha_t_prev = get_alpha(prev_t) if prev_t >= 0 else torch.tensor(1.0)
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))

        # 2. Calculate the x_0 prediction
        pred_x0 = (x_t - torch.sqrt(1.0 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
        # Optional: clamp predicted x_0 to [-1, 1] or another valid range
        # pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

        # 3. Calculate the direction pointing to x_t
        dir_xt = torch.sqrt(1.0 - alpha_t_prev - sigma_t**2) * predicted_noise

        # 4. Calculate noise for stochasticity (if eta > 0)
        random_noise = torch.randn_like(x_t) if eta > 0 and prev_t >= 0 else torch.zeros_like(x_t)

        # 5. Compute the next sample x_{t-1}
        x_prev = torch.sqrt(alpha_t_prev) * pred_x0 + dir_xt + sigma_t * random_noise
        x_t = x_prev
        # --- End of DDIM step ---
    return x_t  # Return the final generated sample(s), x_0
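A sketch of how the loop might be invoked; model, text_encoder, and the shapes below are placeholders for your own components, not a fixed API:

batch_size = 4
x_T = torch.randn(batch_size, 3, 64, 64)              # start from pure Gaussian noise
timesteps = list(range(999, -1, -20))                 # e.g., 50 evenly spaced DDIM steps
condition = text_encoder(["a photo of a cat"] * batch_size)  # hypothetical encoder
null_condition = text_encoder([""] * batch_size)
samples = ddim_sample_loop_cfg(model, x_T, timesteps, condition,
                               null_condition, guidance_scale=7.5)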
Note: The get_alpha(t) function retrieves the cumulative product $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$ from the noise schedule for timestep $t$. The exact implementation depends on how your noise schedule is defined.
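For concreteness, a minimal sketch assuming the standard linear $\beta$ schedule from DDPM; your schedule and timestep indexing may differ:

T = 1000
betas = torch.linspace(1e-4, 0.02, T)               # linear variance schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative product for t = 0..T-1

def get_alpha(t):
    return alphas_cumprod[t]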
The guidance_scale (s) is a hyperparameter that controls the trade-off between sample quality/diversity and adherence to the condition c.
Experimentation is needed to find the optimal s for your specific model, dataset, and task. You might generate samples with varying s values and evaluate them qualitatively or using appropriate metrics.
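One simple way to run such a sweep, reusing the inputs from the invocation sketch above:

results = {}
for s in [1.0, 3.0, 5.0, 7.5, 10.0]:
    results[s] = ddim_sample_loop_cfg(model, x_T, timesteps, condition,
                                      null_condition, guidance_scale=s)
# Inspect results[s] side by side to judge adherence vs. diversity.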
Figure: subjective illustration of how increasing the guidance scale s typically improves condition adherence (e.g., matching a text prompt) while potentially decreasing overall sample diversity or introducing artifacts at very high values. Optimal values balance these aspects.
By implementing CFG, you gain a powerful method for controlling conditional generation in diffusion models without relying on external classifiers, directly leveraging the learned representations within the generative model itself. Experimenting with the guidance scale is an important part of achieving the desired output for your specific application.