Having reviewed the standard noise schedules and their limitations, let's put theory into practice. In this hands-on section, we will implement several noise schedule variants using PyTorch. Understanding how to generate and manipulate these schedules is fundamental for controlling the diffusion process and tailoring it to specific data or generation requirements.
We'll implement the common linear and cosine schedules, alongside a simple custom example, and visualize their behavior. This practical exercise will solidify your understanding of how the variance $\beta_t$ changes over time $t$ and how this impacts the cumulative signal retention, represented by $\bar{\alpha}_t$.
First, let's import PyTorch and define some helper functions that are common across all schedule types. We'll need functions to calculate the $\alpha_t = 1 - \beta_t$ values and the cumulative products $\bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i$.
import torch
import torch.nn.functional as F
import math

# Number of diffusion timesteps
T = 1000

def get_alphas_and_cumprod(betas):
    """Calculates alpha_t and alpha_t_cumprod from betas."""
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return alphas, alphas_cumprod

# Example usage placeholder (will be replaced by specific schedule betas)
# betas_example = torch.linspace(0.0001, 0.02, T)
# alphas_example, alphas_cumprod_example = get_alphas_and_cumprod(betas_example)
# print(f"Example Alphas shape: {alphas_example.shape}")
# print(f"Example Alphas Cumprod shape: {alphas_cumprod_example.shape}")
This helper function takes a tensor of $\beta_t$ values and returns the corresponding $\alpha_t$ and $\bar{\alpha}_t$ tensors, which are essential components in both the forward and reverse diffusion processes.
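As a quick sanity check, note that for a constant $\beta$ the cumulative product collapses to $(1-\beta)^t$. The short snippet below is an illustrative check, not part of the helpers above:

# Illustrative sanity check: with a constant beta, alpha_bar_t = (1 - beta)^t
constant_betas = torch.full((T,), 0.01)
_, constant_cumprod = get_alphas_and_cumprod(constant_betas)
expected = 0.99 ** torch.arange(1, T + 1, dtype=torch.float32)
print(constant_cumprod[-1].item(), expected[-1].item())  # both approximately 4.3e-5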
The linear schedule is perhaps the most straightforward. The variance $\beta_t$ increases linearly from a starting value $\beta_{\text{start}}$ to an ending value $\beta_{\text{end}}$ over $T$ timesteps.
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Generates a linear schedule for beta_t."""
    return torch.linspace(beta_start, beta_end, timesteps)
# Generate the linear schedule
betas_linear = linear_beta_schedule(T)
alphas_linear, alphas_cumprod_linear = get_alphas_and_cumprod(betas_linear)
# print(f"Linear Betas (first 5): {betas_linear[:5]}")
# print(f"Linear Alphas Cumprod (first 5): {alphas_cumprod_linear[:5]}")
# print(f"Linear Alphas Cumprod (last 5): {alphas_cumprod_linear[-5:]}")
This schedule adds relatively little noise per step at the beginning, with the per-step variance $\beta_t$ growing at a constant rate until it reaches $\beta_{\text{end}}$ at the final timestep.
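One way to get a feel for this is to check how quickly the cumulative signal $\bar{\alpha}_t$ decays. The snippet below is a small illustrative check using the tensors generated above:

# Illustrative check: when does less than half of the original signal variance remain?
first_below_half = (alphas_cumprod_linear < 0.5).nonzero()[0].item()
print(f"alpha_bar_t drops below 0.5 at t = {first_below_half}")
print(f"alpha_bar at the final step: {alphas_cumprod_linear[-1].item():.2e}")  # very close to zero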
The cosine schedule, proposed by Nichol and Dhariwal (2021), aims to prevent the signal from decaying too quickly early in the process, which can be beneficial for image quality. It defines $\bar{\alpha}_t$ based on a cosine function and then derives $\beta_t$ from it.
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Generates a cosine schedule for beta_t, based on alpha_t_cumprod.
    Proposed in: https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]  # Normalize so alpha_bar_0 = 1
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)  # Clip to prevent numerical issues
# Generate the cosine schedule
betas_cosine = cosine_beta_schedule(T)
alphas_cosine, alphas_cumprod_cosine = get_alphas_and_cumprod(betas_cosine)
# print(f"Cosine Betas (first 5): {betas_cosine[:5]}")
# print(f"Cosine Alphas Cumprod (first 5): {alphas_cumprod_cosine[:5]}")
# print(f"Cosine Alphas Cumprod (last 5): {alphas_cumprod_cosine[-5:]}")
Note the small offset $s$, used to prevent $\beta_t$ from being too small near $t = 0$. The clipping ensures numerical stability. Observe how the $\bar{\alpha}_t$ values decrease more slowly initially compared to the linear schedule.
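A quick numerical comparison makes this concrete. The snippet below is illustrative and assumes the linear schedule tensors from earlier are still in scope:

# Illustrative comparison: the cosine schedule retains more signal at early timesteps
for step in [50, 100, 250]:
    lin = alphas_cumprod_linear[step].item()
    cos = alphas_cumprod_cosine[step].item()
    print(f"t={step}: linear alpha_bar={lin:.4f}, cosine alpha_bar={cos:.4f}")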
To illustrate the flexibility, let's implement a quadratic schedule where $\beta_t$ increases quadratically. This is just one example; you could design schedules based on sigmoid functions, exponential functions, or any other monotonically increasing function depending on the desired noise injection profile.
def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Generates a quadratic schedule for beta_t."""
    betas_quad = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
    return betas_quad
# Generate the quadratic schedule
betas_quadratic = quadratic_beta_schedule(T)
alphas_quadratic, alphas_cumprod_quadratic = get_alphas_and_cumprod(betas_quadratic)
# print(f"Quadratic Betas (first 5): {betas_quadratic[:5]}")
# print(f"Quadratic Alphas Cumprod (first 5): {alphas_cumprod_quadratic[:5]}")
# print(f"Quadratic Alphas Cumprod (last 5): {alphas_cumprod_quadratic[-5:]}")
This schedule shares the same $\beta_{\text{start}}$ and $\beta_{\text{end}}$ as the linear schedule, but it adds noise even more slowly at the beginning and its per-step variance rises faster towards the end, as the quick numerical check below confirms.
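You can verify this directly from the generated tensors; the comparison below is purely illustrative:

# Illustrative comparison: same endpoints, but the quadratic schedule stays lower in between
for name, betas in [('linear', betas_linear), ('quadratic', betas_quadratic)]:
    print(f"{name}: beta_1={betas[0].item():.2e}, "
          f"beta_mid={betas[T // 2].item():.2e}, beta_T={betas[-1].item():.2e}")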
Comparing these schedules visually makes their different characteristics clear. Let's plot both the $\beta_t$ values (variance added at step $t$) and the $\bar{\alpha}_t$ values (cumulative signal remaining at step $t$).
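Here is a minimal plotting sketch (assuming matplotlib is installed; the exact styling is up to you):

import matplotlib.pyplot as plt

timesteps = torch.arange(T)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Per-step variance beta_t for each schedule
ax1.plot(timesteps, betas_linear, label='linear')
ax1.plot(timesteps, betas_cosine, label='cosine')
ax1.plot(timesteps, betas_quadratic, label='quadratic')
ax1.set_xlabel('timestep t')
ax1.set_ylabel('beta_t')
ax1.legend()

# Cumulative signal retention alpha_bar_t for each schedule
ax2.plot(timesteps, alphas_cumprod_linear, label='linear')
ax2.plot(timesteps, alphas_cumprod_cosine, label='cosine')
ax2.plot(timesteps, alphas_cumprod_quadratic, label='quadratic')
ax2.set_xlabel('timestep t')
ax2.set_ylabel('alpha_bar_t')
ax2.legend()

plt.tight_layout()
plt.show()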
Comparison of $\beta_t$ values for linear, cosine, and quadratic schedules over $T = 1000$ timesteps.
Comparison of $\bar{\alpha}_t$ values for linear, cosine, and quadratic schedules over $T = 1000$ timesteps.
The plots clearly show the differences: the linear schedule lets $\bar{\alpha}_t$ fall off earliest, while the cosine and quadratic schedules retain more signal through the early and middle timesteps. All three drive $\bar{\alpha}_t$ close to zero by $t = T$, but they reach that point along noticeably different paths.
Within a typical diffusion model implementation (often a PyTorch nn.Module), you would pre-compute these schedule-related tensors and register them as buffers. These buffers can then be indexed by the timestep $t$ during both training (for the forward process $q(x_t \mid x_0)$) and sampling (for the reverse process $p_\theta(x_{t-1} \mid x_t)$).
Here's a simplified sketch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleDiffusionModel(nn.Module):
    def __init__(self, schedule_type='linear', timesteps=1000, beta_start=0.0001, beta_end=0.02):
        super().__init__()
        self.timesteps = timesteps

        if schedule_type == 'linear':
            betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        elif schedule_type == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        elif schedule_type == 'quadratic':
            betas = quadratic_beta_schedule(timesteps, beta_start, beta_end)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # Prepend alpha_bar_0 = 1

        # Register buffers so they move with the model (e.g. .to(device)) but are not trained
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Other derived quantities used in the forward/reverse process
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        # ... potentially others for DDPM/DDIM sampling ...

    def forward_process(self, x_0, t, noise=None):
        """Applies the forward diffusion process q(x_t | x_0)."""
        if noise is None:
            noise = torch.randn_like(x_0)
        # Extract values for the given batch of timesteps t and reshape to (B, 1, 1, 1)
        # so they broadcast against image tensors of shape (B, C, H, W)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t, None, None, None]
        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
        return x_t

    # ... forward method for prediction (e.g., noise prediction) ...
    # ... sampling methods (DDPM, DDIM) using the buffers ...

# Example instantiation
# model = SimpleDiffusionModel(schedule_type='cosine')
# print(model.betas.shape)
# print(model.sqrt_alphas_cumprod.shape)
This example shows how the chosen schedule directly populates the necessary tensors within the model structure. The specific values derived (like sqrt_alphas_cumprod) depend directly on the $\beta_t$ values generated by the selected schedule function.
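To see the pieces working together, here is a brief, illustrative usage example (the batch size and image shape are arbitrary):

# Illustrative usage of the forward process (assumes the class defined above)
model = SimpleDiffusionModel(schedule_type='cosine')
x_0 = torch.randn(8, 3, 32, 32)                 # dummy batch of "images"
t = torch.randint(0, model.timesteps, (8,))     # one random timestep per sample
x_t = model.forward_process(x_0, t)
print(x_t.shape)  # torch.Size([8, 3, 32, 32])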
By implementing and visualizing these schedules, you've gained practical insight into how they shape the diffusion process. Remember that the choice of schedule is a design decision that can impact training dynamics and final sample quality. Learned schedules, which we discussed earlier, represent a further step where the model optimizes these variance steps itself, but understanding these fixed, analytical schedules provides the necessary foundation. As we move forward to more complex architectures and training techniques, the noise schedule remains a fundamental component influencing the model's behavior.