Let's put the theory into practice by modifying a standard U-Net block to include attention mechanisms. As discussed earlier, adding attention, particularly self-attention, allows the model to capture long-range dependencies within the image features, which can be beneficial for generating coherent structures. We'll focus on adding a spatial self-attention layer within a residual block commonly found in U-Net architectures used for diffusion.

We assume you have a working understanding of PyTorch and the basic components of a U-Net (convolutional layers, normalization, activation functions, residual connections). Our goal is to augment an existing block structure.

## Prerequisites: A Basic Residual Block

Let's start with a typical residual block often used in U-Nets. It might look something like this (simplified):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Optional time embedding projection
        self.time_mlp = None
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2)  # For scale and shift
            )

        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.SiLU()  # Swish activation is common

        # Skip connection matching
        if in_channels != out_channels:
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x, t_emb=None):
        residual = x
        h = self.norm1(x)
        h = self.activation(h)
        h = self.conv1(h)

        if self.time_mlp is not None and t_emb is not None:
            time_encoding = self.time_mlp(t_emb)
            time_encoding = time_encoding.view(h.shape[0], h.shape[1] * 2, 1, 1)
            scale, shift = torch.chunk(time_encoding, 2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift  # Modulate features
        else:
            h = self.norm2(h)  # Apply norm if no time embedding modulation

        h = self.activation(h)
        h = self.conv2(h)

        return h + self.skip_connection(residual)
```

This `ResidualBlock` includes Group Normalization, convolutional layers, a SiLU activation, and a residual connection, along with optional time embedding modulation.
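Before adding attention, it can help to confirm the block behaves as expected. The following is a quick sanity check; the channel sizes and embedding dimension here are arbitrary choices for illustration:

```python
# Quick shape check for the ResidualBlock above (dimensions chosen arbitrarily).
block = ResidualBlock(in_channels=64, out_channels=128, time_emb_dim=256)

x = torch.randn(4, 64, 32, 32)   # batch of 4 feature maps, 64 channels, 32x32
t_emb = torch.randn(4, 256)      # matching batch of time embeddings

out = block(x, t_emb)
print(out.shape)                 # expected: torch.Size([4, 128, 32, 32])
```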
## Implementing a Multi-Head Self-Attention Block

Now, let's define a self-attention block suitable for image features. Since standard multi-head self-attention (MHSA) operates on sequences, we need to reshape the 2D spatial feature maps into sequences.

```python
import math

class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0, f"Channels ({channels}) must be divisible by num_heads ({num_heads})"
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)

        self.norm = nn.GroupNorm(32, channels)
        # Query, Key, and Value projections
        self.to_qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        # Output projection
        self.to_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        residual = x
        qkv = self.to_qkv(self.norm(x))

        # Reshape for attention:
        # (b, c, h, w) -> (b, num_heads, head_dim, h*w) -> (b*num_heads, head_dim, h*w)
        q, k, v = qkv.chunk(3, dim=1)
        q = q.reshape(b * self.num_heads, self.head_dim, h * w)
        k = k.reshape(b * self.num_heads, self.head_dim, h * w)
        v = v.reshape(b * self.num_heads, self.head_dim, h * w)

        # Scaled dot-product attention:
        # (b*num_heads, h*w, head_dim) @ (b*num_heads, head_dim, h*w) -> (b*num_heads, h*w, h*w)
        attention_scores = torch.bmm(q.transpose(1, 2), k) * self.scale
        attention_probs = F.softmax(attention_scores, dim=-1)

        # (b*num_heads, h*w, h*w) @ (b*num_heads, h*w, head_dim) -> (b*num_heads, h*w, head_dim)
        out = torch.bmm(attention_probs, v.transpose(1, 2))

        # Reshape back: (b*num_heads, h*w, head_dim) -> (b*num_heads, head_dim, h*w) -> (b, c, h, w)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        out = self.to_out(out)

        return out + residual  # Add residual connection
```

Important points in `AttentionBlock`:

- We use `GroupNorm` before attention, a common practice.
- `to_qkv` projects the input into Query, Key, and Value using a 1x1 convolution.
- The spatial dimensions `(h, w)` are flattened into a sequence of length `h*w`.
- Standard scaled dot-product attention is performed.
- The output is reshaped back to `(b, c, h, w)`.
- A final 1x1 convolution projects the features back.
- A residual connection is added.

## Integrating Attention into the U-Net

Attention blocks are typically inserted after residual blocks, especially at lower spatial resolutions (deeper in the U-Net) where feature maps are smaller (e.g., 16x16 or 8x8) and the computational cost is more manageable.

Here's how you might modify the U-Net structure, specifically within the encoder or decoder loops. Note that these snippets are purely illustrative: in a real model the blocks would be created once in `__init__` rather than instantiated inside the forward pass (see the sketch after the decoder example).

```python
# Example snippet within a U-Net encoder loop (illustrative only;
# in practice these modules are created in __init__, not per forward pass)
# ... previous layers ...
x = ResidualBlock(in_channels, out_channels, time_emb_dim)(x, t_emb)

# Add attention block if dimensions are suitable (e.g., lower resolution)
if x.shape[-1] <= 16:  # Apply attention at resolutions 16x16 or lower
    x = AttentionBlock(out_channels, num_heads=8)(x)
# ... potentially more residual blocks or pooling ...
```

Similarly, in the decoder:

```python
# Example snippet within a U-Net decoder loop
# ... upsampling and concatenation ...
x = ResidualBlock(in_channels, out_channels, time_emb_dim)(x, t_emb)

# Add attention block consistent with encoder
if x.shape[-1] <= 16:
    x = AttentionBlock(out_channels, num_heads=8)(x)
# ... more layers ...
```
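To make the pattern concrete, here is a minimal sketch of how these pieces might be wired into a single encoder stage. The `DownStage` name, the strided downsampling convolution, and the `use_attention` flag are our own illustrative choices rather than part of a reference implementation; the key point is that the sub-blocks are registered in `__init__` so their parameters are trained.

```python
class DownStage(nn.Module):
    """One encoder stage: ResidualBlock, optional AttentionBlock, then downsampling."""
    def __init__(self, in_channels, out_channels, time_emb_dim, use_attention=False, num_heads=8):
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_emb_dim)
        # Only instantiate attention where we intend to use it,
        # typically at 16x16 resolution or below.
        self.attn = AttentionBlock(out_channels, num_heads) if use_attention else nn.Identity()
        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, t_emb):
        x = self.res_block(x, t_emb)
        x = self.attn(x)
        skip = x                    # kept for the decoder's skip connection
        x = self.downsample(x)
        return x, skip
```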
The diagram below illustrates where an attention block might fit within a sequence of operations in a U-Net layer.

```dot
digraph G {
    rankdir=LR;
    node [shape=box, style=filled, fillcolor="#e9ecef", fontname="sans-serif"];
    edge [fontname="sans-serif"];

    Input [label="Input Features (x)"];
    ResBlock [label="ResidualBlock", fillcolor="#a5d8ff"];
    AttnBlock [label="AttentionBlock", fillcolor="#ffec99"];
    Output [label="Output Features"];

    Input -> ResBlock;
    ResBlock -> AttnBlock [label="If resolution is low"];
    AttnBlock -> Output;
    ResBlock -> Output [label="If resolution is high"];
}
```

Diagram showing the flow through a standard `ResidualBlock`, potentially followed by an `AttentionBlock` before producing the final output for that layer.

## Experimentation and Approaches

Now that you have the building blocks, consider these points:

- **Placement:** Experiment with placing attention blocks at different resolutions. Adding them too early (at high resolution) can be computationally expensive, while adding them only at the bottleneck might limit their impact. A common choice is to add them after each residual block in the lower-resolution layers.
- **Number of Heads:** The `num_heads` parameter controls the granularity of the attention mechanism. More heads allow attending to different feature subspaces but increase computation slightly. Values like 4, 8, or 16 are common. Ensure the channel dimension is divisible by `num_heads`.
- **Cross-Attention:** For conditional models (e.g., text-to-image), you would implement a `CrossAttentionBlock` instead of, or in addition to, self-attention. The key and value inputs would come from the conditioning context (e.g., projected text embeddings), while the query would come from the image features `x`. This is typically added in similar locations to self-attention; a minimal sketch is given at the end of this section.
- **Computational Cost:** Attention mechanisms, especially self-attention, have a computational cost quadratic in the sequence length (here, `h*w`). This is why they are often restricted to lower resolutions.
- **Training Stability:** Ensure appropriate normalization (like the `GroupNorm` used here) and residual connections are used around attention blocks to maintain stable training.

By integrating attention mechanisms, you equip your U-Net to better model complex spatial relationships and potentially integrate conditioning information more effectively, pushing the capabilities of your diffusion model. Remember to test these modifications carefully, monitoring both sample quality and training performance.
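As a starting point for the cross-attention experiment mentioned above, here is a minimal sketch of a cross-attention block, reusing the imports from earlier. It mirrors `AttentionBlock`, except that keys and values are projected from a conditioning tensor `context` of shape `(batch, seq_len, context_dim)`. The class name, the `context_dim` parameter, and the projection choices are illustrative assumptions, not a definitive implementation.

```python
class CrossAttentionBlock(nn.Module):
    """Sketch of cross-attention: queries from image features,
    keys/values from a conditioning context (e.g., text embeddings)."""
    def __init__(self, channels, context_dim, num_heads=8):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)

        self.norm = nn.GroupNorm(32, channels)
        self.to_q = nn.Conv2d(channels, channels, kernel_size=1)
        self.to_k = nn.Linear(context_dim, channels)
        self.to_v = nn.Linear(context_dim, channels)
        self.to_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x, context):
        # x: (b, c, h, w) image features; context: (b, seq_len, context_dim)
        b, c, h, w = x.shape
        residual = x

        q = self.to_q(self.norm(x)).reshape(b * self.num_heads, self.head_dim, h * w)
        k = self.to_k(context)  # (b, seq_len, c)
        v = self.to_v(context)

        # (b, seq_len, c) -> (b*num_heads, head_dim, seq_len)
        k = k.transpose(1, 2).reshape(b * self.num_heads, self.head_dim, -1)
        v = v.transpose(1, 2).reshape(b * self.num_heads, self.head_dim, -1)

        # Each spatial position attends over the context tokens:
        # (b*num_heads, h*w, seq_len)
        scores = torch.bmm(q.transpose(1, 2), k) * self.scale
        probs = F.softmax(scores, dim=-1)

        out = torch.bmm(probs, v.transpose(1, 2))       # (b*num_heads, h*w, head_dim)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return self.to_out(out) + residual
```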