Now that we've discussed the U-Net architecture and the importance of incorporating timestep information, let's translate these ideas into practical code snippets. This section provides examples using PyTorch to illustrate how you might implement the core components: the U-Net structure itself and the sinusoidal timestep embeddings.
As covered previously, the network needs to know at which point in the diffusion process it's operating. Sinusoidal embeddings provide a way to encode the timestep t into a fixed-size vector that the network can easily use.
The formula for these embeddings involves sine and cosine functions applied to the timestep, scaled across different frequencies:
$$PE_{(t,\,2i)} = \sin\!\left(\frac{t}{10000^{2i/d}}\right) \qquad PE_{(t,\,2i+1)} = \cos\!\left(\frac{t}{10000^{2i/d}}\right)$$
Here, t is the timestep, i indexes the embedding dimension, and d is the total embedding dimension. This technique allows the model to distinguish between different timesteps effectively.
Let's implement a function to generate these embeddings in PyTorch:
import torch
import torch.nn as nn
import math

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # The dimension of the embeddings

    def forward(self, time):
        """
        Generates sinusoidal embeddings for a batch of timesteps.

        Args:
            time (torch.Tensor): A tensor of shape (batch_size,) containing timesteps.

        Returns:
            torch.Tensor: A tensor of shape (batch_size, dim) containing the embeddings.
        """
        device = time.device
        half_dim = self.dim // 2
        # Calculate frequencies (exponents for 10000)
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Calculate arguments for sine and cosine
        embeddings = time[:, None] * embeddings[None, :]
        # Concatenate sine and cosine components
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # Handle odd dimensions if necessary (though typically 'dim' is even)
        if self.dim % 2 == 1:
            embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1)
        return embeddings
# Example Usage:
time_embedding_dim = 128 # Choose an embedding dimension
time_embed_module = SinusoidalPosEmb(time_embedding_dim)
# Simulate a batch of timesteps
batch_size = 16
timesteps = torch.randint(0, 1000, (batch_size,)) # Example: 1000 diffusion steps
# Generate embeddings
time_embeddings = time_embed_module(timesteps)
print(f"Timesteps shape: {timesteps.shape}")
print(f"Generated embeddings shape: {time_embeddings.shape}")
# Expected output:
# Timesteps shape: torch.Size([16])
# Generated embeddings shape: torch.Size([16, 128])
This SinusoidalPosEmb module takes a batch of timesteps and produces the corresponding embeddings. These embeddings will later be projected and added into the U-Net layers.
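Before it reaches the individual blocks, the sinusoidal embedding is often refined by a small MLP that is shared across the whole network. The following is a minimal sketch of that pattern; the class name TimeEmbedding and the choice of two linear layers are illustrative assumptions rather than part of any specific library.
class TimeEmbedding(nn.Module):
    # Sketch only: refine the raw sinusoidal embedding with a small shared MLP.
    def __init__(self, dim):
        super().__init__()
        self.embed = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, time):
        # (batch,) -> (batch, dim)
        return self.mlp(self.embed(time))
The output of such a module would then be passed as the timestep embedding to each block of the U-Net.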
The U-Net relies on reusable blocks for downsampling and upsampling. A typical block includes convolutions, normalization (like Group Normalization, which works well with varying batch sizes), activation functions (like SiLU or GeLU), and potentially residual connections. Timestep embeddings are often incorporated within these blocks.
Here's a simplified example of a residual block that includes timestep embedding integration:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, groups=8):
        super().__init__()
        # Project timestep embeddings to match channel dimensions
        self.time_mlp = nn.Sequential(
            nn.SiLU(),  # Swish activation
            nn.Linear(time_emb_dim, out_channels)
        )

        # First convolutional layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.act1 = nn.SiLU()

        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act2 = nn.SiLU()

        # Residual connection handling (if input/output channels differ)
        if in_channels != out_channels:
            self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.res_conv = nn.Identity()  # No change needed

    def forward(self, x, t_emb):
        """
        Args:
            x (torch.Tensor): Input feature map (batch, in_channels, height, width).
            t_emb (torch.Tensor): Timestep embeddings (batch, time_emb_dim).

        Returns:
            torch.Tensor: Output feature map (batch, out_channels, height, width).
        """
        h = self.conv1(x)
        h = self.norm1(h)
        h = self.act1(h)

        # Incorporate timestep embedding
        time_bias = self.time_mlp(t_emb)
        # Reshape time_bias to (batch, out_channels, 1, 1) for broadcasting
        h = h + time_bias[:, :, None, None]

        h = self.conv2(h)
        h = self.norm2(h)
        h = self.act2(h)

        # Add residual connection
        return h + self.res_conv(x)
# Example Usage (requires a sample input and time embedding)
in_channels = 32
out_channels = 64
time_emb_dim = 128
img_size = 32 # Example image dimension
# Create dummy input tensor and time embeddings
x_sample = torch.randn(batch_size, in_channels, img_size, img_size)
t_emb_sample = torch.randn(batch_size, time_emb_dim)  # In practice, use the embeddings produced by SinusoidalPosEmb
# Instantiate and run the block
block = ResidualBlock(in_channels, out_channels, time_emb_dim)
output = block(x_sample, t_emb_sample)
print(f"Input shape: {x_sample.shape}")
print(f"Time embedding shape: {t_emb_sample.shape}")
print(f"Output shape: {output.shape}")
# Expected output:
# Input shape: torch.Size([16, 32, 32, 32])
# Time embedding shape: torch.Size([16, 128])
# Output shape: torch.Size([16, 64, 32, 32])
In this ResidualBlock, the timestep embedding t_emb is first passed through a small multi-layer perceptron (time_mlp) consisting of an activation (SiLU) and a linear layer. This projects the embedding to out_channels dimensions. The result is then added as a bias to the feature map h after the first convolution and normalization. The [:, :, None, None] indexing reshapes the bias so it broadcasts across the spatial dimensions (height and width). The block finishes with a second convolution, normalization, activation, and the addition of the residual connection (res_conv(x)).
A full U-Net uses these blocks along with downsampling layers (e.g., nn.MaxPool2d or strided convolution) and upsampling layers (e.g., nn.Upsample or nn.ConvTranspose2d). Skip connections concatenate feature maps from the downsampling path to corresponding layers in the upsampling path.
Below is a high-level sketch of how these components fit together. Note that this is a simplified representation; actual U-Nets often have multiple blocks per resolution level and may include attention mechanisms.
Diagram illustrating the flow within a U-Net architecture used for noise prediction. Input image and timestep embedding are processed through downsampling blocks (encoder), a bottleneck, and upsampling blocks (decoder) with skip connections. Timestep information is typically injected into multiple blocks.
Building a full U-Net involves stacking these ResidualBlock components, adding appropriate downsampling (e.g., nn.Conv2d with stride=2) and upsampling layers (e.g., nn.ConvTranspose2d), and implementing the skip connections (often simple tensor concatenation along the channel dimension). The forward method of the main U-Net class orchestrates this flow, passing the input image x and the generated timestep embeddings t_emb through the network structure.
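To make that orchestration concrete, here is a heavily simplified sketch that wires together the ResidualBlock defined above with the Downsample and Upsample helpers sketched earlier. The class name SimpleUNet, the channel counts, and the use of a single block per resolution level are illustrative assumptions, not a reference implementation.
class SimpleUNet(nn.Module):
    # Sketch: one ResidualBlock per level, two resolution levels, no attention.
    def __init__(self, img_channels=3, base_channels=64, time_emb_dim=128):
        super().__init__()
        # Raw sinusoidal embeddings (the shared refinement MLP is omitted for brevity)
        self.time_embed = SinusoidalPosEmb(time_emb_dim)

        # Encoder (downsampling path)
        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)
        self.down1 = ResidualBlock(base_channels, base_channels, time_emb_dim)
        self.downsample1 = Downsample(base_channels)
        self.down2 = ResidualBlock(base_channels, base_channels * 2, time_emb_dim)
        self.downsample2 = Downsample(base_channels * 2)

        # Bottleneck
        self.mid = ResidualBlock(base_channels * 2, base_channels * 2, time_emb_dim)

        # Decoder (upsampling path); input channels double because of skip concatenation
        self.upsample2 = Upsample(base_channels * 2)
        self.up2 = ResidualBlock(base_channels * 4, base_channels, time_emb_dim)
        self.upsample1 = Upsample(base_channels)
        self.up1 = ResidualBlock(base_channels * 2, base_channels, time_emb_dim)

        self.out_conv = nn.Conv2d(base_channels, img_channels, 1)

    def forward(self, x, time):
        t_emb = self.time_embed(time)

        h = self.init_conv(x)
        h1 = self.down1(h, t_emb)                       # skip connection source 1
        h2 = self.down2(self.downsample1(h1), t_emb)    # skip connection source 2

        h_mid = self.mid(self.downsample2(h2), t_emb)

        # Upsample and concatenate the matching encoder features along the channel dimension
        h = self.up2(torch.cat([self.upsample2(h_mid), h2], dim=1), t_emb)
        h = self.up1(torch.cat([self.upsample1(h), h1], dim=1), t_emb)
        return self.out_conv(h)
Calling SimpleUNet()(torch.randn(16, 3, 32, 32), timesteps) with the timesteps from earlier would return a tensor of shape (16, 3, 32, 32), matching the input image, which is exactly what a noise-prediction network needs.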
Remember that this is a foundational setup. Real-world implementations often incorporate more sophisticated elements like attention layers (particularly at lower resolutions) and careful tuning of channel counts and block depths, which are essential for achieving high-quality generation results. However, understanding these basic building blocks provides a solid starting point for working with diffusion model architectures.
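For reference, the kind of spatial self-attention block commonly inserted at the lower-resolution levels might look roughly like the following single-head, unoptimized sketch; the class name SpatialSelfAttention is illustrative, and production implementations are typically multi-head and more careful about memory.
class SpatialSelfAttention(nn.Module):
    # Sketch: single-head self-attention over all spatial positions of a feature map.
    def __init__(self, channels, groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        # Compute queries, keys, and values with 1x1 convolutions
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        # Flatten spatial dimensions: each tensor becomes (b, c, h*w)
        q, k, v = (t.reshape(b, c, h * w) for t in (q, k, v))
        # Attention weights over all spatial positions: (b, h*w, h*w)
        attn = torch.softmax(q.transpose(1, 2) @ k / math.sqrt(c), dim=-1)
        # Weighted sum of values, reshaped back into a feature map
        out = (v @ attn.transpose(1, 2)).reshape(b, c, h, w)
        return x + self.proj(out)  # residual connection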