To effectively guide the diffusion model's generation process using conditioning information $y$, we often need to adapt the underlying neural network architecture. The standard U-Net, designed to predict noise $\epsilon_\theta(x_t, t)$ based on the noisy input $x_t$ and timestep $t$, must be modified to also incorporate $y$. The goal is to make the noise prediction conditional, becoming $\epsilon_\theta(x_t, t, y)$.
Several strategies exist for integrating this conditioning information, with the choice often depending on the nature and dimensionality of y.
When the conditioning information y is relatively simple, such as a class label for image generation (e.g., "cat", "dog", corresponding to integer labels 0, 1), we can use straightforward techniques.
Embedding and Adding to Timestep Embeddings: A common approach, particularly effective with classifier-free guidance (CFG), is to treat the class label y similarly to the timestep t.
The class label $y$ is first mapped to a learned embedding vector, for example with an embedding layer (`torch.nn.Embedding` in PyTorch). Let's call this $e_y$. Just as the timestep embedding is injected into the U-Net's residual blocks, $e_y$ is projected to the same dimension and added to the timestep embedding, so the combined vector conditions every block it reaches. This approach effectively informs the network about the desired class at multiple levels of feature processing.
Diagram illustrating the process of combining timestep and class label embeddings before injecting them into U-Net blocks.
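A minimal PyTorch sketch of this additive approach is shown below. The module name (`ConditioningEmbedder`), the embedding size of 256, and the assumption that a sinusoidal timestep embedding `t_emb` is computed elsewhere are all illustrative choices, not the layout of any particular implementation.

```python
import torch
import torch.nn as nn

class ConditioningEmbedder(nn.Module):
    """Combines timestep and class-label embeddings into one vector
    that is injected into each U-Net block (illustrative sketch)."""

    def __init__(self, num_classes: int, emb_dim: int = 256):
        super().__init__()
        # Learned lookup table: each class label maps to a vector e_y
        self.label_emb = nn.Embedding(num_classes, emb_dim)
        # Small MLP applied to the (assumed precomputed) timestep embedding
        self.time_mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim), nn.SiLU(), nn.Linear(emb_dim, emb_dim)
        )

    def forward(self, t_emb: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # t_emb: (batch, emb_dim) sinusoidal timestep embedding
        # y:     (batch,) integer class labels
        e_y = self.label_emb(y)              # (batch, emb_dim)
        return self.time_mlp(t_emb) + e_y    # summed conditioning vector


# Usage: the resulting vector is typically passed to every residual block,
# e.g. added to feature maps after a per-block linear projection.
embedder = ConditioningEmbedder(num_classes=10)
t_emb = torch.randn(4, 256)                  # placeholder timestep embeddings
y = torch.tensor([0, 1, 3, 7])               # class labels for the batch
cond = embedder(t_emb, y)                    # (4, 256)
```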
Concatenation: Another method involves concatenating the embedded conditioning information $e_y$ with the input $x_t$ along the channel dimension. This augmented input is then fed into the U-Net. Alternatively, $e_y$ could be spatially broadcast (tiled to match the spatial dimensions of an intermediate feature map) and concatenated with feature maps deeper within the network. While simple, this might not always propagate the conditioning signal through the U-Net as effectively as the additive embedding method.
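A short sketch of this concatenation variant is given below; the tensor shapes and the embedding dimension are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn

# Assumed shapes: x_t is (batch, C, H, W); y is (batch,) integer labels.
batch, C, H, W, emb_dim, num_classes = 4, 3, 32, 32, 8, 10

label_emb = nn.Embedding(num_classes, emb_dim)
x_t = torch.randn(batch, C, H, W)
y = torch.tensor([0, 1, 2, 3])

# Embed the label and tile it across the spatial dimensions
e_y = label_emb(y)                                      # (batch, emb_dim)
e_y_map = e_y[:, :, None, None].expand(-1, -1, H, W)    # (batch, emb_dim, H, W)

# Concatenate along the channel dimension; the first U-Net convolution
# must now accept C + emb_dim input channels instead of C.
x_cond = torch.cat([x_t, e_y_map], dim=1)               # (batch, C + emb_dim, H, W)
```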
When conditioning on richer, higher-dimensional information like text descriptions, simple addition or concatenation is usually insufficient. Text requires capturing sequential dependencies and nuanced meanings. For this, cross-attention has become the standard mechanism, forming the backbone of modern text-to-image models like Stable Diffusion.
Cross-Attention Mechanism
Recall that self-attention layers within a U-Net (often in Transformer-style blocks) allow different spatial locations in the image representation to attend to each other. Cross-attention layers work similarly but allow the image representation to attend to the conditioning information y.
Here's how it typically works within a U-Net block designed for conditional generation (e.g., text-to-image):
Inputs: The layer receives two main inputs: the current image representation $z$ (an intermediate U-Net feature map, flattened into a sequence of spatial positions) and the conditioning sequence $c$ (e.g., the token embeddings produced by a text encoder from the prompt $y$).
Query, Key, Value: The Queries $Q$ are computed from the image features $z$ through a learned projection, while the Keys $K$ and Values $V$ are computed from the conditioning sequence $c$ through their own learned projections.
Attention Calculation: The core operation calculates attention scores based on the similarity between Queries (from image) and Keys (from conditioning). A common way is scaled dot-product attention:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
Here, $d_k$ is the dimension of the key vectors. The softmax ensures the attention weights for each query sum to 1.
Output: The result is a weighted sum of the Value vectors (derived from conditioning c), where the weights are determined by how relevant each part of the conditioning sequence is to each part of the image representation. This output is then typically added back to the original image features z (often via a residual connection) or further processed within the U-Net block.
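These steps can be condensed into a small single-head cross-attention module, sketched below. The dimensions are assumptions for illustration; production models such as Stable Diffusion use multi-head attention together with normalization and additional projections around this core.

```python
import math
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Single-head cross-attention: image features attend to conditioning."""

    def __init__(self, img_dim: int, cond_dim: int, attn_dim: int = 64):
        super().__init__()
        self.to_q = nn.Linear(img_dim, attn_dim, bias=False)   # Q from image features
        self.to_k = nn.Linear(cond_dim, attn_dim, bias=False)  # K from conditioning
        self.to_v = nn.Linear(cond_dim, attn_dim, bias=False)  # V from conditioning
        self.to_out = nn.Linear(attn_dim, img_dim)             # project back to image dim

    def forward(self, z: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # z: (batch, num_pixels, img_dim)   flattened spatial features
        # c: (batch, num_tokens, cond_dim)  conditioning sequence (e.g., text embeddings)
        q, k, v = self.to_q(z), self.to_k(c), self.to_v(c)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (batch, pixels, tokens)
        weights = scores.softmax(dim=-1)     # each pixel's weights over tokens sum to 1
        out = self.to_out(weights @ v)       # weighted sum of Values, back to img_dim
        return z + out                       # residual connection back into z


# Usage with illustrative shapes
attn = CrossAttention(img_dim=320, cond_dim=768)
z = torch.randn(2, 16 * 16, 320)   # a 16x16 feature map, flattened
c = torch.randn(2, 77, 768)        # 77 conditioning token embeddings
z_new = attn(z, c)                 # (2, 256, 320)
```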
Integration into U-Net
Cross-attention layers are usually inserted into multiple blocks of the U-Net, particularly within the downsampling, bottleneck, and upsampling paths. This allows the conditioning information y to influence the denoising process at various levels of feature abstraction.
Diagram showing the integration of conditioning embeddings (c) into image features (z) using a cross-attention layer within a U-Net block. Queries are derived from z, while Keys and Values come from c.
Example: Text-to-Image Generation
In a text-to-image model, the prompt is first processed by a pretrained text encoder (such as CLIP's text transformer), producing a sequence of token embeddings that serves as the conditioning $c$. Within each cross-attention layer, the U-Net's spatial features supply the Queries while these token embeddings supply the Keys and Values, so every spatial location can weight the words most relevant to it, for example attending strongly to "astronaut" in the region where that object is being formed.
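As one concrete sketch of the text side, the snippet below encodes a prompt into the token-embedding sequence $c$ using the Hugging Face `transformers` CLIP text encoder, a common choice in models like Stable Diffusion; the specific checkpoint name is an illustrative assumption.

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

# One common choice of text encoder; the checkpoint is illustrative.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a photograph of an astronaut riding a horse"
tokens = tokenizer(prompt, padding="max_length", max_length=77,
                   truncation=True, return_tensors="pt")

with torch.no_grad():
    # last_hidden_state: (1, 77, 768) sequence of token embeddings = c
    c = text_encoder(**tokens).last_hidden_state

# c is then passed to every cross-attention layer of the U-Net as Keys/Values,
# while Queries come from the U-Net's own spatial feature maps.
```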
By modifying the U-Net architecture to incorporate conditioning information y, particularly using mechanisms like additive embeddings or cross-attention, diffusion models gain the ability to perform controlled generation, producing outputs tailored to specific requirements like class labels or detailed text descriptions.