Alright, let's translate the theory of DDPM and DDIM sampling into practice. We've discussed how the reverse process works by iteratively denoising, starting from pure noise $x_T$. Now we'll implement the loops that perform this denoising step by step.
We assume you have a trained noise prediction model, often a U-Net, which we'll refer to as `model`. This model takes a noisy input $x_t$ and the timestep $t$, and outputs the predicted noise $\epsilon_\theta(x_t, t)$. We also assume you have precomputed the noise schedule variables: $\beta_t$, $\alpha_t = 1 - \beta_t$, and $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$. These are typically stored as tensors.
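For concreteness, here is one common way to construct these tensors. This sketch assumes the linear beta schedule from the original DDPM paper; whatever schedule you use must match the one your model was trained with:

```python
import torch

# Minimal sketch of precomputing the schedule tensors, assuming a
# linear beta schedule; must match the schedule used during training.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # beta_t
alphas = 1.0 - betas                           # alpha_t = 1 - beta_t
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t = prod_{i<=t} alpha_i
```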
The DDPM sampling process follows the reverse Markov chain defined in Chapter 3. Starting with $x_T \sim \mathcal{N}(0, I)$, we iteratively sample $x_{t-1}$ from $p_\theta(x_{t-1} \mid x_t)$ for $t = T, T-1, \ldots, 1$.
Recall the core equation for the denoising step:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sigma_t z$$

where $z \sim \mathcal{N}(0, I)$ is fresh noise added at each step, and $\sigma_t^2$ is the variance of the reverse transition, often set to $\beta_t$ or $\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t$. Using $\tilde{\beta}_t$ is common.
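To make the correspondence between the equation and the code explicit, here is that single update in isolation (a minimal sketch; the full sampler below wraps it in a loop and precomputes the schedule terms):

```python
import torch

# One DDPM denoising step, written to mirror the equation above.
# Assumes 1-D schedule tensors indexed by t, with posterior_variance[t]
# holding sigma_t^2 (e.g., beta_tilde_t).
def ddpm_step(xt, predicted_noise, t, alphas, alphas_cumprod, posterior_variance):
    mean = (1.0 / torch.sqrt(alphas[t])) * (
        xt - (1.0 - alphas[t]) / torch.sqrt(1.0 - alphas_cumprod[t]) * predicted_noise
    )
    z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)  # no noise at t = 0
    return mean + torch.sqrt(posterior_variance[t]) * z
```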
Here is the full DDPM sampling loop, written with PyTorch conventions:
```python
import torch

def ddpm_sampler(model, n_steps, shape, device, betas, alphas, alphas_cumprod):
    """
    Generates samples using the DDPM algorithm.

    Args:
        model: The trained noise prediction model (U-Net).
        n_steps (int): Total number of diffusion steps (T).
        shape: Shape of the desired output tensor,
            e.g. [batch_size, channels, height, width].
        device: Device to perform computations on, e.g. 'cuda' or 'cpu'.
        betas (torch.Tensor): Beta values of the noise schedule, shape [T].
        alphas (torch.Tensor): Alpha values (1 - beta), shape [T].
        alphas_cumprod (torch.Tensor): Cumulative products of alphas, shape [T].

    Returns:
        torch.Tensor: The generated samples (x_0).
    """
    # 1. Start with pure random noise x_T
    xt = torch.randn(shape, device=device)

    # Precompute required schedule terms
    sqrt_alphas = torch.sqrt(alphas).to(device)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod).to(device)

    # Posterior variance beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
    # Defining alpha_bar_{-1} := 1.0 makes the entry at t = 0 exactly zero,
    # so no special-casing is needed (no noise is added at t = 0 anyway).
    alphas_cumprod_prev = torch.cat(
        [torch.ones(1, device=alphas_cumprod.device), alphas_cumprod[:-1]]
    )
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    sqrt_posterior_variance = torch.sqrt(posterior_variance).to(device)

    # 2. Iteratively denoise; zero-indexed t = n_steps-1, ..., 0
    #    corresponds to t = T, ..., 1 in the equations.
    for t in reversed(range(n_steps)):
        # The model usually expects a timestep tensor of shape [batch_size]
        time_tensor = torch.full((shape[0],), t, dtype=torch.long, device=device)

        # Predict the noise; no gradient tracking is needed during sampling
        with torch.no_grad():
            predicted_noise = model(xt, time_tensor)

        # Mean of the reverse distribution p(x_{t-1} | x_t):
        # (1 / sqrt(alpha_t)) * (xt - ((1 - alpha_t) / sqrt(1 - alpha_bar_t)) * eps_theta)
        alpha_t = alphas[t]
        mean = (1.0 / sqrt_alphas[t]) * (
            xt - ((1.0 - alpha_t) / sqrt_one_minus_alphas_cumprod[t]) * predicted_noise
        )

        # sigma_t: standard deviation of the reverse transition (sqrt already applied)
        sigma_t = sqrt_posterior_variance[t]

        # Fresh noise z ~ N(0, I), except at the final step (t = 0)
        z = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)

        # Compute x_{t-1}
        xt = mean + sigma_t * z
        # Optional: clamp pixel values for images, e.g. xt.clamp_(-1., 1.)

    # 3. Return the final denoised sample x_0
    return xt
```
```python
# --- Usage example (assuming model and schedule variables are defined) ---
# T = 1000
# image_shape = [1, 3, 64, 64]  # batch size 1, 3 channels, 64x64 pixels
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# generated_image = ddpm_sampler(unet_model, T, image_shape, device,
#                                betas_tensor, alphas_tensor, alphas_cumprod_tensor)
```
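If you want to check the loop mechanics before plugging in a real network, a stand-in noise predictor is enough. The `DummyEps` module below is hypothetical; any callable with the signature `model(x, t)` returning a tensor shaped like `x` would work:

```python
import torch

# Hypothetical stand-in for a trained U-Net, used only to smoke-test
# the sampling loop; it predicts zero noise everywhere.
class DummyEps(torch.nn.Module):
    def forward(self, x, t):
        return torch.zeros_like(x)

# Tiny schedule so the loop runs in milliseconds
T = 10
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

out = ddpm_sampler(DummyEps(), T, [2, 3, 8, 8], 'cpu', betas, alphas, alphas_cumprod)
print(out.shape)  # torch.Size([2, 3, 8, 8])
```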
Key aspects of this implementation:

- Initialization: `xt` starts as pure Gaussian noise, $x_T \sim \mathcal{N}(0, I)$.
- Loop direction: the loop runs from `n_steps - 1` down to `0`; these zero-indexed timesteps correspond to $t = T, \ldots, 1$ in the equations.
- Noise prediction: each iteration calls `model` to get `predicted_noise`. Remember `torch.no_grad()` for efficiency.
- Stochasticity: fresh noise `z` is sampled and added, scaled by the standard deviation term $\sigma_t$. Note that no noise is added at the very last step, when predicting $x_0$ from $x_1$.
- Update: `xt` is set to the newly calculated $x_{t-1}$ for the next iteration.

DDIM provides a faster sampling alternative by defining a non-Markovian generative process. It allows skipping steps and uses a parameter $\eta$ (eta) to control the trade-off between determinism ($\eta = 0$) and stochasticity ($\eta = 1$). When $\eta = 1$, the update matches DDPM; when $\eta = 0$, the process becomes fully deterministic given the initial noise $x_T$.
The DDIM update rule is:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \, \hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \cdot \epsilon_\theta(x_t, t) + \sigma_t z$$

where $\hat{x}_0$ is the predicted clean sample:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \, \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}$$

and the standard deviation $\sigma_t$ is controlled by $\eta$:

$$\sigma_t^2 = \eta^2 \tilde{\beta}_t = \eta^2 \, \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \left( 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}} \right)$$

Notice that if $\eta = 0$, then $\sigma_t = 0$ and the last term vanishes, making the process deterministic. DDIM also allows using a subsequence of timesteps (e.g., every 10th step) for faster generation; in that case $\bar{\alpha}_{t-1}$ above refers to the previous timestep in the chosen subsequence.
Here is the corresponding DDIM sampling loop in PyTorch:
```python
import torch
import numpy as np

def ddim_sampler(model, n_inference_steps, shape, device, alphas_cumprod, eta=0.0):
    """
    Generates samples using the DDIM algorithm.

    Args:
        model: The trained noise prediction model (U-Net).
        n_inference_steps (int): Number of denoising steps at inference time
            (can be far fewer than T).
        shape: Shape of the desired output tensor.
        device: Device to perform computations on.
        alphas_cumprod (torch.Tensor): Cumulative products of alpha values
            from the full training schedule (T steps).
        eta (float): Controls stochasticity (0.0 = deterministic, 1.0 = DDPM-like).

    Returns:
        torch.Tensor: The generated samples (x_0).
    """
    # 1. Choose the subsequence of timesteps used at inference.
    # Here: n_inference_steps equally spaced timesteps in [0, T-1];
    # more sophisticated spacing strategies exist.
    n_train_steps = len(alphas_cumprod)
    inference_timesteps = np.linspace(0, n_train_steps - 1, n_inference_steps, dtype=int)
    timesteps = list(reversed(inference_timesteps.tolist()))  # descending order

    # 2. Start with pure random noise x_T
    xt = torch.randn(shape, device=device)
    alphas_cumprod = alphas_cumprod.to(device)

    # 3. Iteratively denoise along the inference subsequence
    for i, t in enumerate(timesteps):
        # Previous timestep in the subsequence; -1 marks the final step to x_0
        t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1

        # The model usually expects a timestep tensor of shape [batch_size]
        time_tensor = torch.full((shape[0],), t, dtype=torch.long, device=device)

        with torch.no_grad():
            predicted_noise = model(xt, time_tensor)

        # alpha_bar for current and previous timesteps (alpha_bar := 1 at "t = -1")
        alpha_cumprod_t = alphas_cumprod[t]
        alpha_cumprod_t_prev = (
            alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
        )

        # Predicted clean sample:
        # x0_hat = (xt - sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_bar_t)
        x0_pred = (
            xt - torch.sqrt(1.0 - alpha_cumprod_t) * predicted_noise
        ) / torch.sqrt(alpha_cumprod_t)

        # sigma_t = eta * sqrt( (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        #                       * (1 - alpha_bar_t / alpha_bar_prev) )
        sigma_t = eta * torch.sqrt(
            (1.0 - alpha_cumprod_t_prev) / (1.0 - alpha_cumprod_t)
            * (1.0 - alpha_cumprod_t / alpha_cumprod_t_prev)
        )

        # "Direction pointing to x_t": sqrt(1 - alpha_bar_prev - sigma_t^2) * eps_theta
        # (clamped at 0 to guard against tiny negative values from rounding)
        dir_xt = torch.sqrt(
            torch.clamp(1.0 - alpha_cumprod_t_prev - sigma_t**2, min=0.0)
        ) * predicted_noise

        # Fresh noise only when the process is stochastic
        z = torch.randn_like(xt) if eta > 0 else torch.zeros_like(xt)

        # x_{t-1} = sqrt(alpha_bar_prev) * x0_hat + dir_xt + sigma_t * z
        xt = torch.sqrt(alpha_cumprod_t_prev) * x0_pred + dir_xt + sigma_t * z
        # Optional: xt.clamp_(-1., 1.)

    # 4. Return the final sample x_0
    return xt
```
```python
# --- Usage example ---
# inference_steps = 50  # far fewer than T = 1000
# eta_value = 0.0       # deterministic sampling
# generated_image_ddim = ddim_sampler(unet_model, inference_steps, image_shape,
#                                     device, alphas_cumprod_tensor, eta=eta_value)
```
Key differences in the DDIM implementation:

- Timestep selection: sampling uses only `n_inference_steps` timesteps chosen from the original `n_train_steps` (illustrated in the snippet after this list).
- The `eta` parameter: controls the amount of noise added at each step; `eta=0` makes the process deterministic.
- Larger jumps: the loop iterates over `inference_timesteps`, potentially taking much larger jumps than the single steps in DDPM.

Figure: Comparison of sampling step sequences for DDPM and DDIM. DDIM often uses fewer, larger steps.
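To make those larger jumps concrete, here is the subsequence the `np.linspace` strategy produces, assuming $T = 1000$ training steps and 10 inference steps:

```python
import numpy as np

# Equally spaced subsequence of 10 timesteps out of 1000,
# visited in reverse order during sampling.
steps = np.linspace(0, 999, 10, dtype=int)
print(steps[::-1])  # [999 888 777 666 555 444 333 222 111   0]
```

Each update jumps 111 training steps at once, which is why DDIM can produce a sample in a small fraction of the full chain's cost.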
These code structures provide a foundation. You'll need to integrate them with your specific model loading, data handling, and the precomputed noise schedule used during your model's training. Experimenting with `n_inference_steps` and `eta` in DDIM lets you explore the speed versus quality trade-offs discussed previously.
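As a starting point for that experimentation, a simple sweep might look like the following sketch. The names `unet_model`, `image_shape`, `device`, and `alphas_cumprod_tensor` are hypothetical placeholders for your own objects:

```python
# Sweep over step counts and eta values to map the speed/quality
# trade-off; score or save each batch with your preferred metric.
for steps in (10, 25, 50, 100):
    for eta in (0.0, 0.5, 1.0):
        samples = ddim_sampler(unet_model, steps, image_shape, device,
                               alphas_cumprod_tensor, eta=eta)
        # evaluate_or_save(samples, steps, eta)  # your evaluation hook
```

Fewer steps cut generation time roughly linearly, while nonzero `eta` reintroduces stochasticity; the best settings depend on your model and data.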