Classifier-free guidance (CFG) enables conditional generation without requiring a separate classifier model. This hands-on exercise demonstrates how to modify a standard diffusion model sampling loop to incorporate CFG. We assume you have access to a pre-trained conditional diffusion model (such as a U-Net or DiT) and a basic sampling function (e.g., one implementing DDIM).

The core idea of CFG, as discussed, is to compute two predictions at each timestep $t$ during the reverse process: a conditional prediction $\epsilon_\theta(x_t, c)$ using the guidance condition $c$ (e.g., a text embedding or class label), and an unconditional prediction $\epsilon_\theta(x_t, \emptyset)$ using a null or empty condition $\emptyset$. These are then combined using a guidance scale $s$ (often denoted $w$ in implementations) to steer the generation:

$$ \tilde{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \emptyset) + s \cdot \left( \epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \emptyset) \right) $$

This adjusted noise estimate $\tilde{\epsilon}_\theta(x_t, c)$ is then used in the sampler's update step. A scale of $s=0$ corresponds to unconditional generation, while $s=1$ uses only the conditional prediction (assuming the model was trained with conditioning dropout). Values $s > 1$ amplify the guidance signal.

## Prerequisites

- A pre-trained conditional diffusion model that accepts conditioning information $c$ and a representation for the null condition $\emptyset$. Models trained with conditioning dropout (randomly setting $c=\emptyset$ during training) are suitable.
- A function implementing a diffusion sampler (e.g., DDIM).
- Conditioning information $c$ (e.g., text embeddings) and a corresponding null condition tensor $\emptyset$.

## Modifying the Sampling Loop

Let's assume you have a standard DDIM sampling function that looks something like this (simplified Python-like pseudocode):

```python
def ddim_sample_loop(model, x_T, timesteps, condition, eta=0.0):
    x_t = x_T
    for t_idx, t in enumerate(timesteps):
        # Get current and previous timestep
        time_tensor = torch.tensor([t], device=x_t.device)
        prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1

        # 1. Predict noise using the model
        predicted_noise = model(x_t, time_tensor, condition)  # Original prediction

        # 2. Calculate x_0 prediction (using DDIM formula components)
        # ... calculate alpha_t, alpha_t_prev, pred_x0 ...

        # 3. Calculate direction pointing to x_t
        # ... calculate dir_xt ...

        # 4. Calculate noise for stochasticity (if eta > 0)
        # ... calculate sigma_t and random_noise ...

        # 5. Compute the next sample x_{t-1}
        x_prev = torch.sqrt(alpha_t_prev) * pred_x0 + dir_xt + sigma_t * random_noise
        x_t = x_prev
    return x_t
```

To implement CFG, we need to modify step 1, where the model predicts the noise:

1. **Prepare Inputs:** Ensure you have both the target condition $c$ and the null condition $\emptyset$. The null condition is often a tensor of zeros with the same shape as $c$, or a specific embedding learned for the unconditional case (see the sketch after this list).
2. **Batch Inputs (Optional but Efficient):** If you are processing multiple samples, or if the model supports batching, you can concatenate the conditional and unconditional inputs along the batch dimension and perform both predictions in a single forward pass. For a single sample, this means creating a batch of size 2.
3. **Perform Predictions:** Call the model twice (or once with a batch): `noise_pred_cond = model(x_t, time_tensor, c)` and `noise_pred_uncond = model(x_t, time_tensor, null_condition)`.
4. **Combine Predictions:** Apply the CFG formula with the guidance scale $s$ (`guidance_scale` in code): `predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)`.
5. **Use the Combined Prediction:** Use this `predicted_noise` in the subsequent steps (2-5) of the original sampling loop.
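For step 1 above, here is a minimal sketch of two common ways to build the null condition. The `[1, 77, 768]` shape and the `text_encoder` / `tokenize` names are hypothetical placeholders for whatever your model provides, not a specific library API:

```python
import torch

# Placeholder conditioning embedding; the actual shape depends on your encoder.
condition = torch.randn(1, 77, 768)

# Option A: a zeros tensor with the same shape as the conditioning embedding.
null_condition = torch.zeros_like(condition)

# Option B (common in text-to-image models): encode an empty prompt with the same
# text encoder used to produce `condition`. `text_encoder` and `tokenize` are
# hypothetical stand-ins for your model's own components.
# null_condition = text_encoder(tokenize([""]))
```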
## Implementation Example (PyTorch-like)

Here's how the modified loop might look, assuming `model` takes `x_t`, `time_tensor`, and `cond` as inputs and `guidance_scale` ($s$) is provided:

```python
import torch


def ddim_sample_loop_cfg(model, x_T, timesteps, condition, null_condition,
                         guidance_scale, eta=0.0):
    x_t = x_T
    batch_size = x_t.shape[0]  # Assuming x_T has shape [B, C, H, W]

    for t_idx, t in enumerate(timesteps):
        time_tensor = torch.full((batch_size,), t, device=x_t.device, dtype=torch.long)
        prev_t = timesteps[t_idx + 1] if t_idx < len(timesteps) - 1 else -1

        # 1. Predict noise using the model with CFG.
        # Efficiently predict both conditional and unconditional noise;
        # requires the model to handle batched conditions.
        model_input = torch.cat([x_t] * 2)          # Duplicate input for cond/uncond
        time_input = torch.cat([time_tensor] * 2)
        condition_input = torch.cat([condition, null_condition])  # Batch conditions

        # Single model call for efficiency
        noise_pred_combined = model(model_input, time_input, condition_input)

        # Split the predictions (same order as the concatenation above)
        noise_pred_cond, noise_pred_uncond = noise_pred_combined.chunk(2)

        # Apply the CFG formula
        predicted_noise = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # --- Remainder of DDIM step (steps 2-5) ---
        # Calculate alpha_t, alpha_t_prev, sigma_t based on t and prev_t
        # (assuming these values are precomputed or derived from a noise schedule).
        alpha_t = get_alpha(t)
        # Use a tensor for the final step so the torch.sqrt calls below still work.
        alpha_t_prev = get_alpha(prev_t) if prev_t >= 0 else torch.tensor(1.0, device=x_t.device)
        sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))

        # 2. Calculate x_0 prediction
        pred_x0 = (x_t - torch.sqrt(1.0 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
        # Optional: clamp predicted x_0 to [-1, 1] or another valid range
        # pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

        # 3. Calculate direction pointing to x_t
        dir_xt = torch.sqrt(1.0 - alpha_t_prev - sigma_t**2) * predicted_noise

        # 4. Calculate noise for stochasticity (if eta > 0)
        random_noise = torch.randn_like(x_t) if eta > 0 and prev_t >= 0 else torch.zeros_like(x_t)

        # 5. Compute the next sample x_{t-1}
        x_prev = torch.sqrt(alpha_t_prev) * pred_x0 + dir_xt + sigma_t * random_noise
        x_t = x_prev
        # --- End of DDIM step ---

    return x_t  # Return the final generated sample(s) x_0
```

Note: the `get_alpha(t)` function retrieves the cumulative product $\bar{\alpha}_t$ of the noise schedule coefficients for timestep $t$. The exact implementation depends on how your noise schedule is defined.
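As a usage illustration, the following sketch invokes the loop end to end. The stand-in `model`, the toy `get_alpha` schedule, the tensor shapes, and the 50-step schedule are placeholder assumptions so the snippet runs; substitute your trained model and real noise schedule:

```python
import torch

# Stand-ins so this sketch runs; replace with a trained conditional U-Net/DiT
# and your actual noise schedule.
def model(x_t, t, cond):
    return torch.randn_like(x_t)  # Dummy "noise prediction" with the right shape

def get_alpha(t):
    return torch.tensor(0.999 ** (t + 1))  # Toy cumulative schedule, illustration only

# Placeholder shapes; they depend on your model and conditioning encoder.
x_T = torch.randn(1, 3, 64, 64)            # Start from pure Gaussian noise
timesteps = list(range(999, -1, -20))      # e.g., 50 evenly spaced DDIM steps
text_embedding = torch.randn(1, 77, 768)   # Pretend text-conditioning embedding
null_embedding = torch.zeros_like(text_embedding)

sample = ddim_sample_loop_cfg(
    model, x_T, timesteps,
    condition=text_embedding,
    null_condition=null_embedding,
    guidance_scale=7.5,  # A common moderate setting; tune for your model
)
print(sample.shape)  # torch.Size([1, 3, 64, 64])
```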
## Tuning the Guidance Scale ($s$)

The `guidance_scale` ($s$) is a hyperparameter that controls the trade-off between sample quality/diversity and adherence to the condition $c$:

- **Low $s$ (e.g., 1.0 - 3.0):** Samples tend to be more diverse and potentially higher fidelity in terms of pure image quality, but may follow the condition $c$ less strictly. $s=0$ is purely unconditional.
- **Moderate $s$ (e.g., 4.0 - 8.0):** Often provides a good balance. Samples adhere well to the condition without sacrificing too much quality. This is a common range for many text-to-image models.
- **High $s$ (e.g., 9.0 - 15.0+):** Samples strongly follow the condition $c$. However, high values can sometimes lead to saturation, artifacts, or reduced diversity, as the model is pushed strongly toward the conditional prediction.

Experimentation is needed to find the optimal $s$ for your specific model, dataset, and task. You might generate samples with varying $s$ values and evaluate them qualitatively or with appropriate metrics (see the sweep sketch at the end of this section).

*Figure: "Illustrative Trade-off with Guidance Scale (s)": condition adherence and sample diversity/quality (subjective score, higher is better) plotted against the guidance scale. Subjective illustration of how increasing the guidance scale $s$ typically improves condition adherence (e.g., matching a text prompt) while potentially decreasing overall sample diversity or introducing artifacts at very high values. Optimal values balance these aspects.*

## Further Notes

- **Computational Cost:** CFG roughly doubles the computational cost per sampling step compared to standard conditional sampling because it requires two model forward passes (or one pass through a doubled batch).
- **Model Training:** This technique works best when the model has been trained with conditioning dropout, allowing it to learn both conditional and unconditional generation paths effectively.
- **Sampler Interaction:** The CFG adjustment is typically applied to the noise prediction ($\epsilon$) before it is used by the sampler (DDIM, DPM-Solver, etc.). The rest of the sampler logic usually remains unchanged.

By implementing CFG, you gain a powerful method for controlling conditional generation in diffusion models without relying on external classifiers, directly leveraging the learned representations within the generative model itself. Experimenting with the guidance scale is an important part of achieving the desired output for your specific application.
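To make that experimentation concrete, here is a minimal sketch of a guidance-scale sweep that reuses `ddim_sample_loop_cfg` and the stand-ins (`model`, `timesteps`, `text_embedding`, `null_embedding`) from the usage sketch above; the specific scale values are illustrative, not prescriptive:

```python
import torch

# Sweep a few guidance scales from the same starting noise so the comparison
# isolates the effect of s rather than the random seed.
scales_to_try = [1.5, 3.0, 5.0, 7.5, 10.0, 15.0]
x_T = torch.randn(1, 3, 64, 64)

samples = {}
for s in scales_to_try:
    samples[s] = ddim_sample_loop_cfg(
        model, x_T.clone(), timesteps,
        condition=text_embedding,
        null_condition=null_embedding,
        guidance_scale=s,
    )

# `samples` now maps each scale to a generated tensor; inspect the outputs
# visually or score them with a metric of your choice (e.g., prompt similarity).
```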