Let's put the theory into practice by modifying a standard U-Net block to include attention mechanisms. As discussed earlier, adding attention, particularly self-attention, allows the model to capture long-range dependencies within the image features, which can be beneficial for generating coherent structures. We'll focus on adding a spatial self-attention layer within a residual block commonly found in U-Net architectures used for diffusion.
We assume you have a working understanding of PyTorch and the basic components of a U-Net (convolutional layers, normalization, activation functions, residual connections). Our goal is to augment an existing block structure.
Let's start with a typical residual block often used in U-Nets. It might look something like this (simplified):
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Optional time embedding projection
        self.time_mlp = None
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2)  # For scale and shift
            )

        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.SiLU()  # Swish activation is common

        # Skip connection matching
        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x, t_emb=None):
        residual = x
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)

        if self.time_mlp is not None and t_emb is not None:
            time_encoding = self.time_mlp(t_emb)
            time_encoding = time_encoding.view(h.shape[0], h.shape[1] * 2, 1, 1)
            scale, shift = torch.chunk(time_encoding, 2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift  # Modulate features
        else:
            h = self.norm2(h)  # Apply norm if no time embedding modulation

        h = self.activation(h)
        h = self.conv2(h)
        return h + self.skip_connection(residual)
This ResidualBlock includes Group Normalization, convolutional layers, a SiLU activation, and a residual connection, along with optional time embedding modulation.
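To confirm the shapes line up, here is a quick standalone check. This is a minimal sketch; the batch size, channel counts, spatial size, and time embedding dimension are arbitrary values chosen for illustration:

# Quick shape check with illustrative sizes
block = ResidualBlock(in_channels=64, out_channels=128, time_emb_dim=256)

x = torch.randn(4, 64, 32, 32)   # (batch, channels, height, width)
t_emb = torch.randn(4, 256)      # per-sample time embedding

out = block(x, t_emb)
print(out.shape)                 # torch.Size([4, 128, 32, 32])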
Now, let's define a self-attention block suitable for image features. Since standard multi-head self-attention (MHSA) operates on sequences, we need to reshape the 2D spatial feature maps into sequences.
import math

class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0, f"Channels ({channels}) must be divisible by num_heads ({num_heads})"
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)

        self.norm = nn.GroupNorm(32, channels)
        # Query, Key, Value projections
        self.to_qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        # Output projection
        self.to_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        residual = x

        qkv = self.to_qkv(self.norm(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Flatten spatial dims and split heads:
        # (b, c, h, w) -> (b*num_heads, head_dim, h*w)
        q = q.reshape(b * self.num_heads, self.head_dim, h * w)
        k = k.reshape(b * self.num_heads, self.head_dim, h * w)
        v = v.reshape(b * self.num_heads, self.head_dim, h * w)

        # Scaled dot-product attention
        # (b*num_heads, h*w, head_dim) @ (b*num_heads, head_dim, h*w) -> (b*num_heads, h*w, h*w)
        attention_scores = torch.bmm(q.transpose(1, 2), k) * self.scale
        attention_probs = F.softmax(attention_scores, dim=-1)

        # (b*num_heads, h*w, h*w) @ (b*num_heads, h*w, head_dim) -> (b*num_heads, h*w, head_dim)
        out = torch.bmm(attention_probs, v.transpose(1, 2))

        # Reshape back: (b*num_heads, h*w, head_dim) -> (b, c, h, w)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        out = self.to_out(out)

        return out + residual  # Add residual connection
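Because the block reshapes the features internally, a quick shape check is worthwhile before wiring it into the network. The sizes here are illustrative; 256 channels at a 16x16 resolution is typical for a deep U-Net stage:

# Illustrative shape check at a low-resolution stage
attn = AttentionBlock(channels=256, num_heads=8)

feat = torch.randn(2, 256, 16, 16)  # (batch, channels, height, width)
out = attn(feat)
print(out.shape)                    # torch.Size([2, 256, 16, 16]), shape is preserved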
Key points in AttentionBlock:

- GroupNorm is applied before attention, a common practice.
- to_qkv projects the input into Query, Key, and Value using a 1x1 convolution.
- The spatial dimensions (h, w) are flattened into a sequence of length h*w so that standard attention can operate on them.
- After attention, the output is reshaped back to (b, c, h, w) and added to the input through a residual connection.

Attention blocks are typically inserted after residual blocks, especially at lower spatial resolutions (deeper in the U-Net) where feature maps are smaller (e.g., 16x16 or 8x8) and the computational cost is more manageable.
Here's how you might modify the U-Net structure, specifically within the encoder or decoder loops:
# Example snippet within a U-Net encoder loop
# ... previous layers ...
x = ResidualBlock(in_channels, out_channels, time_emb_dim)(x, t_emb)

# Add attention block if dimensions are suitable (e.g., lower resolution)
if x.shape[-1] <= 16:  # Apply attention at resolutions 16x16 or lower
    x = AttentionBlock(out_channels, num_heads=8)(x)
# ... potentially more residual blocks or pooling ...
Similarly, in the decoder:
# Example snippet within a U-Net decoder loop
# ... upsampling and concatenation ...
x = ResidualBlock(in_channels, out_channels, time_emb_dim)(x, t_emb)

# Add attention block consistent with the encoder
if x.shape[-1] <= 16:
    x = AttentionBlock(out_channels, num_heads=8)(x)
# ... more layers ...
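The snippets above are schematic: in a real model, the blocks are registered as submodules in __init__ so their parameters are trained and tracked, rather than being constructed inside the forward pass. Here is a minimal sketch of one encoder stage; the class name DownStage and the use_attention flag are illustrative choices, not part of any specific library:

class DownStage(nn.Module):
    # One encoder stage: residual block, optional attention, then downsampling.
    # Illustrative sketch; names and structure are assumptions, not a fixed API.
    def __init__(self, in_channels, out_channels, time_emb_dim, use_attention=False):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_emb_dim)
        self.attn = AttentionBlock(out_channels, num_heads=8) if use_attention else nn.Identity()
        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, t_emb):
        x = self.res_block(x, t_emb)
        x = self.attn(x)  # no-op when attention is disabled
        return self.downsample(x)

The use_attention flag would be set per resolution level when the U-Net is built, mirroring the x.shape[-1] <= 16 check above.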
The diagram below illustrates where an attention block might fit within a sequence of operations in a U-Net layer.
Diagram showing the flow through a standard ResidualBlock, potentially followed by an AttentionBlock before producing the final output for that layer.
Now that you have the building blocks, consider these points:

- The num_heads parameter controls the granularity of the attention mechanism. More heads allow attending to different feature subspaces but increase computation slightly. Values like 4, 8, or 16 are common. Ensure the channel dimension is divisible by num_heads.
- For conditional generation, you might use a CrossAttentionBlock instead of or in addition to self-attention. The key and value inputs would come from the conditioning context (e.g., projected text embeddings), while the query would come from the image features x. This is typically added in similar locations to self-attention; a minimal sketch appears at the end of this section.
- Attention cost grows quadratically with the sequence length (h*w here). This is why attention blocks are often restricted to lower resolutions.
- Normalization (the GroupNorm used here) and residual connections are used around attention blocks to maintain stable training.

By integrating attention mechanisms, you equip your U-Net to better model complex spatial relationships and potentially integrate conditioning information more effectively, pushing the capabilities of your diffusion model. Remember to test these modifications carefully, monitoring both sample quality and training performance.
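For reference, here is the minimal cross-attention sketch mentioned in the list above. It is not the implementation from any particular library; the context_dim parameter and the assumption that the context arrives as a (batch, sequence_length, context_dim) tensor of projected text embeddings are illustrative choices:

class CrossAttentionBlock(nn.Module):
    # Sketch of cross-attention: queries from image features, keys/values from context.
    # Shapes and parameter names are assumptions for illustration.
    def __init__(self, channels, context_dim, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)

        self.norm = nn.GroupNorm(32, channels)
        self.to_q = nn.Conv2d(channels, channels, kernel_size=1)
        self.to_k = nn.Linear(context_dim, channels)
        self.to_v = nn.Linear(context_dim, channels)
        self.to_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x, context):
        # x: (b, c, h, w) image features; context: (b, seq_len, context_dim)
        b, c, h, w = x.shape
        residual = x

        q = self.to_q(self.norm(x)).reshape(b * self.num_heads, self.head_dim, h * w)
        k = self.to_k(context)  # (b, seq_len, c)
        v = self.to_v(context)
        seq_len = k.shape[1]
        # (b, seq_len, c) -> (b*num_heads, head_dim, seq_len)
        k = k.transpose(1, 2).reshape(b * self.num_heads, self.head_dim, seq_len)
        v = v.transpose(1, 2).reshape(b * self.num_heads, self.head_dim, seq_len)

        # (b*num_heads, h*w, head_dim) @ (b*num_heads, head_dim, seq_len) -> (b*num_heads, h*w, seq_len)
        scores = torch.bmm(q.transpose(1, 2), k) * self.scale
        probs = F.softmax(scores, dim=-1)
        # (b*num_heads, h*w, seq_len) @ (b*num_heads, seq_len, head_dim) -> (b*num_heads, h*w, head_dim)
        out = torch.bmm(probs, v.transpose(1, 2))

        out = out.transpose(1, 2).reshape(b, c, h, w)
        return self.to_out(out) + residual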