Having explored the architecture of Diffusion Transformers (DiTs) and how they adapt transformer principles for image generation, let's solidify our understanding by implementing a core component: a single DiT block. This practical exercise demonstrates how conditioning information (like timestep embeddings) is integrated and how self-attention operates on image patch embeddings within the diffusion framework.
We'll build this block using PyTorch, assuming you have a working knowledge of PyTorch's nn.Module and standard transformer components like Layer Normalization, Multi-Head Self-Attention, and MLP layers.
Recall from our discussion of the DiT architecture that it processes a sequence of image patch embeddings along with conditioning information. A standard DiT block typically includes:

- Inputs: a sequence of patch tokens x (shape [batch_size, num_patches, hidden_dim]) and a conditioning vector c (shape [batch_size, embedding_dim], often derived from the timestep and possibly class labels).
- Adaptive layer normalization (adaLN-Zero), whose shift, scale, and gate parameters are generated from the conditioning c. The "Zero" variant initializes the projection producing these parameters such that the block initially behaves like an identity function, promoting stability.
- Multi-Head Self-Attention (MHSA) and MLP sub-layers, each wrapped in a gated residual connection.

The specific arrangement often follows the pre-normalization style common in transformers:

Input -> adaLN -> MHSA -> Residual Add -> adaLN -> MLP -> Residual Add -> Output
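Concretely, writing LN for the parameter-free layer normalization and using the shift, scale, and gate parameters generated from c for each sub-layer, the two residual updates are:

x = x + gate_msa * MHSA(LN(x) * (1 + scale_msa) + shift_msa)
x = x + gate_mlp * MLP(LN(x) * (1 + scale_mlp) + shift_mlp)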
A significant aspect of DiT is how the conditioning c modulates the processing of the image tokens x. This is achieved through the adaLN-Zero mechanism. We need a small sub-network that takes the conditioning vector c and outputs parameters for shifting, scaling, and gating the normalized activations.

Let's define a helper module for this. It takes c and produces shift, scale, and gate parameters. Since adaLN-Zero applies before both the MHSA and MLP layers, we'll need separate sets of these parameters for each, so the output dimension should be 6 * hidden_dim (shift, scale, gate for MHSA; shift, scale, gate for MLP).
import torch
import torch.nn as nn
class AdaLNModulation(nn.Module):
"""
Calculates shift, scale, and gate parameters from conditioning embeddings.
Initializes the final linear layer's weights and biases to zero.
"""
def __init__(self, embedding_dim: int, hidden_dim: int):
super().__init__()
self.silu = nn.SiLU()
# Initialize the linear layer with zeros for the "Zero" aspect
self.linear = nn.Linear(embedding_dim, 6 * hidden_dim, bias=True)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, c):
# c shape: [batch_size, embedding_dim]
mod_params = self.linear(self.silu(c))
# mod_params shape: [batch_size, 6 * hidden_dim]
# Split into 6 parts for shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
# Each part has shape [batch_size, hidden_dim]
        # The caller applies unsqueeze(1) to make them broadcastable with the token dim: [batch_size, 1, hidden_dim]
return mod_params.chunk(6, dim=1)
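As a quick, illustrative sanity check (the sizes here are arbitrary), the module should return six chunks of shape [batch_size, hidden_dim], and because of the zero initialization every chunk starts out all zeros:

# Illustrative shape check (assumes torch and AdaLNModulation from above are in scope)
mod = AdaLNModulation(embedding_dim=256, hidden_dim=768)
c = torch.randn(4, 256)
params = mod(c)
print(len(params))                              # 6
print(params[0].shape)                          # torch.Size([4, 768])
print(all(p.abs().sum() == 0 for p in params))  # True: zero-init means the block starts as an identity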
Now we can assemble the DiTBlock. We'll use standard PyTorch modules for LayerNorm, MultiheadAttention, and the MLP.
import torch
import torch.nn as nn
# Assume AdaLNModulation class from above is defined
class DiTBlock(nn.Module):
"""
A standard DiT block with adaLN-Zero modulation.
"""
def __init__(self, hidden_dim: int, embedding_dim: int, num_heads: int, mlp_ratio: float = 4.0):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
# Normalization layers (using LayerNorm without elementwise affine,
# as adaLN provides the affine parameters)
self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
# Modulation network to generate adaLN parameters
self.modulation = AdaLNModulation(embedding_dim, hidden_dim)
# Attention layer
self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
# MLP layer
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_hidden_dim),
nn.GELU(), # Or nn.SiLU()
nn.Linear(mlp_hidden_dim, hidden_dim),
)
def forward(self, x, c):
# x shape: [batch_size, num_patches, hidden_dim]
# c shape: [batch_size, embedding_dim]
# Calculate modulation parameters (shift, scale, gate for MSA and MLP)
# Each param shape: [batch_size, 1, hidden_dim] after unsqueeze
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
p.unsqueeze(1) for p in self.modulation(c)
]
# --- Apply adaLN-Zero before MHSA ---
x_norm1 = self.norm1(x)
# Modulate: Scale, Shift
x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa
# Apply attention
attn_output, _ = self.attn(x_modulated1, x_modulated1, x_modulated1)
# Apply gating and add residual connection
x = x + gate_msa * attn_output
# --- Apply adaLN-Zero before MLP ---
x_norm2 = self.norm2(x)
# Modulate: Scale, Shift
x_modulated2 = x_norm2 * (1 + scale_mlp) + shift_mlp
# Apply MLP
mlp_output = self.mlp(x_modulated2)
# Apply gating and add residual connection
x = x + gate_mlp * mlp_output
# Output shape: [batch_size, num_patches, hidden_dim]
return x
Initialization (__init__):

- The LayerNorm modules (norm1, norm2) are created with elementwise_affine=False. This is important because the adaptive scale and shift will be provided by our modulation network, not learned statically within the LayerNorm layer itself.
- An AdaLNModulation instance (modulation) is created to compute the 6 necessary parameters (shift, scale, gate for both the MHSA and MLP paths) from the conditioning vector c.
- A MultiheadAttention layer (attn) is defined. batch_first=True makes it easier to work with inputs of shape [batch, sequence, feature].
- The MLP (mlp) is created, typically expanding the hidden dimension by a factor (mlp_ratio) and then projecting back. GELU is a common activation function here.

Forward Pass (forward):

- The conditioning vector c is passed through the modulation network to get the six parameter sets. We unsqueeze(1) to make their shape [batch_size, 1, hidden_dim] so they broadcast correctly across the sequence length (number of patches) dimension during modulation.
- The input x is normalized using norm1.
- x_norm1 is modulated using the scale_msa and shift_msa parameters: x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa. Note the (1 + scale) formulation, common in adaptive normalization schemes.
- x_modulated1 is fed into the self-attention layer (attn).
- The attention output is scaled by gate_msa and added back to the original input x (residual connection): x = x + gate_msa * attn_output.
- The updated x is normalized using norm2.
- x_norm2 is modulated using scale_mlp and shift_mlp.
- x_modulated2 is passed through the mlp.
- The MLP output is scaled by gate_mlp and added back to the input of this sub-layer (residual connection): x = x + gate_mlp * mlp_output.
- The final x is returned.

To build a full DiT, you would stack multiple instances of this DiTBlock. The input x would initially be the patch embeddings plus positional embeddings, and c would be derived from the diffusion timestep (and potentially class labels or other conditioning). The output of the final block would then typically be projected to predict the noise (or the original data x_0, depending on the parameterization).
# Example Usage (Illustrative)
batch_size = 4
num_patches = 196 # e.g., for a 224x224 image with 16x16 patches
hidden_dim = 768
embedding_dim = 256 # Dimension of timestep/class embedding
num_heads = 12
# Sample input tensors
x_patches = torch.randn(batch_size, num_patches, hidden_dim)
conditioning_vec = torch.randn(batch_size, embedding_dim)
# Instantiate the block
dit_block = DiTBlock(hidden_dim, embedding_dim, num_heads)
# Pass inputs through the block
output_tokens = dit_block(x_patches, conditioning_vec)
print(f"Input shape: {x_patches.shape}")
print(f"Conditioning shape: {conditioning_vec.shape}")
print(f"Output shape: {output_tokens.shape}")
# Expected Output:
# Input shape: torch.Size([4, 196, 768])
# Conditioning shape: torch.Size([4, 256])
# Output shape: torch.Size([4, 196, 768])
This hands-on implementation provides a concrete view of how transformer blocks are adapted within the DiT framework, particularly focusing on the adaLN-Zero mechanism for integrating conditioning information effectively. Building upon this single block, one can construct the entire transformer backbone for a diffusion model.
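As a rough illustration of that stacking, here is a minimal sketch of a DiT-style backbone built from the DiTBlock above. The SimpleDiT wrapper, its sinusoidal timestep embedding, and the plain final projection are assumptions made for this example, not a prescribed implementation; a complete DiT would also handle patchification, positional embeddings, and a modulated final layer.

import math
import torch
import torch.nn as nn
# Assume DiTBlock class from above is defined

class SimpleDiT(nn.Module):
    """
    Illustrative stack of DiT blocks (hypothetical wrapper, for intuition only).
    Expects already-patchified tokens with positional embeddings added.
    """
    def __init__(self, hidden_dim=768, embedding_dim=256, depth=4, num_heads=12, out_dim=768):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Small MLP mapping the sinusoidal timestep embedding to the conditioning vector c
        self.time_mlp = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, embedding_dim, num_heads) for _ in range(depth)
        ])
        self.final_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        # Projects each token to the prediction target (e.g., per-patch noise)
        self.final_proj = nn.Linear(hidden_dim, out_dim)

    def timestep_embedding(self, t):
        # Sinusoidal embedding of the diffusion timestep t: [batch_size] -> [batch_size, embedding_dim]
        half = self.embedding_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([torch.cos(args), torch.sin(args)], dim=1)

    def forward(self, x, t):
        # x: [batch_size, num_patches, hidden_dim] (patch embeddings + positional embeddings)
        # t: [batch_size] integer diffusion timesteps
        c = self.time_mlp(self.timestep_embedding(t))
        for block in self.blocks:
            x = block(x, c)
        return self.final_proj(self.final_norm(x))

# Quick shape check with the same illustrative sizes as before
model = SimpleDiT()
tokens = torch.randn(4, 196, 768)
timesteps = torch.randint(0, 1000, (4,))
print(model(tokens, timesteps).shape)  # torch.Size([4, 196, 768])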