Building a single Diffusion Transformer (DiT) block is a practical exercise that illustrates its core components. This implementation clarifies how conditioning information (such as timestep embeddings) is integrated and how self-attention operates on image patch embeddings within the diffusion framework, demonstrating how DiTs adapt transformer principles for image generation.

We'll build this block using PyTorch, assuming you have a working knowledge of PyTorch's `nn.Module` and standard transformer components such as Layer Normalization, Multi-Head Self-Attention, and MLP layers.

## Components of a DiT Block

Recall from our discussion of the DiT architecture that it processes a sequence of image patch embeddings along with conditioning information. A standard DiT block typically includes:

- **Input:** A sequence of token embeddings `x` (shape `[batch_size, num_patches, hidden_dim]`) and a conditioning vector `c` (shape `[batch_size, embedding_dim]`, often derived from the timestep and possibly class labels).
- **Adaptive Layer Normalization (adaLN / adaLN-Zero):** Layer Normalization layers whose scale and shift parameters are dynamically generated from the conditioning vector `c`. The "Zero" variant initializes the projection producing these parameters so that the block initially behaves like an identity function, promoting training stability.
- **Multi-Head Self-Attention (MHSA):** Allows tokens (patches) to attend to each other, capturing spatial relationships.
- **MLP Block:** A standard feed-forward network (usually two linear layers with a non-linearity such as GELU or SiLU) applied independently to each token.
- **Residual Connections:** Summing the input of a sub-layer (MHSA or MLP) with its output, facilitating gradient flow.

The specific arrangement often follows the pre-normalization style common in transformers:

`Input -> adaLN -> MHSA -> Residual Add -> adaLN -> MLP -> Residual Add -> Output`

## Implementing the adaLN-Zero Modulation

A significant aspect of DiT is how the conditioning `c` modulates the processing of the image tokens `x`. This is achieved through the adaLN-Zero mechanism. We need a small sub-network that takes the conditioning vector `c` and outputs parameters for shifting, scaling, and gating the normalized activations.

Let's define a helper module for this. It takes `c` and produces shift, scale, and gate parameters. Since adaLN-Zero applies before both the MHSA and MLP layers, we need a separate set of these parameters for each, so the output dimension is `6 * hidden_dim` (shift, scale, gate for MHSA; shift, scale, gate for MLP).

```python
import torch
import torch.nn as nn


class AdaLNModulation(nn.Module):
    """
    Calculates shift, scale, and gate parameters from conditioning embeddings.
    Initializes the final linear layer's weights and biases to zero.
    """
    def __init__(self, embedding_dim: int, hidden_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        # Initialize the linear layer with zeros for the "Zero" aspect
        self.linear = nn.Linear(embedding_dim, 6 * hidden_dim, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, c):
        # c shape: [batch_size, embedding_dim]
        mod_params = self.linear(self.silu(c))
        # mod_params shape: [batch_size, 6 * hidden_dim]
        # Split into shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,
        # each of shape [batch_size, hidden_dim]. The caller unsqueezes dim 1 to make
        # them broadcastable over the token dimension: [batch_size, 1, hidden_dim].
        return mod_params.chunk(6, dim=1)
```
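As a quick check (a minimal sketch, assuming the `AdaLNModulation` class and the `torch` import above are in scope), we can verify the "Zero" behavior: a freshly constructed modulation network returns all-zero parameters for any conditioning vector.

```python
# Illustrative check: the zero-initialized projection yields six all-zero
# parameter tensors, regardless of the conditioning input.
mod = AdaLNModulation(embedding_dim=256, hidden_dim=768)
c = torch.randn(4, 256)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod(c)
print(shift_msa.shape)                       # torch.Size([4, 768])
print(torch.count_nonzero(gate_msa).item())  # 0 -- the "Zero" in adaLN-Zero
```

This zero start is what lets a freshly initialized stack of DiT blocks begin training from a stable, identity-like configuration, as discussed above.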
## Building the DiT Block

Now we can assemble the `DiTBlock`. We'll use standard PyTorch modules for LayerNorm, MultiheadAttention, and the MLP.

```python
import torch
import torch.nn as nn

# Assumes the AdaLNModulation class from above is defined


class DiTBlock(nn.Module):
    """
    A standard DiT block with adaLN-Zero modulation.
    """
    def __init__(self, hidden_dim: int, embedding_dim: int, num_heads: int,
                 mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Normalization layers (LayerNorm without elementwise affine,
        # since adaLN provides the affine parameters)
        self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)

        # Modulation network to generate the adaLN parameters
        self.modulation = AdaLNModulation(embedding_dim, hidden_dim)

        # Attention layer
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)

        # MLP layer
        mlp_hidden_dim = int(hidden_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_hidden_dim),
            nn.GELU(),  # Or nn.SiLU()
            nn.Linear(mlp_hidden_dim, hidden_dim),
        )

    def forward(self, x, c):
        # x shape: [batch_size, num_patches, hidden_dim]
        # c shape: [batch_size, embedding_dim]

        # Calculate modulation parameters (shift, scale, gate for MSA and MLP)
        # Each parameter has shape [batch_size, 1, hidden_dim] after unsqueeze
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
            p.unsqueeze(1) for p in self.modulation(c)
        ]

        # --- Apply adaLN-Zero before MHSA ---
        x_norm1 = self.norm1(x)
        # Modulate: scale, shift
        x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa
        # Apply attention
        attn_output, _ = self.attn(x_modulated1, x_modulated1, x_modulated1)
        # Apply gating and add the residual connection
        x = x + gate_msa * attn_output

        # --- Apply adaLN-Zero before MLP ---
        x_norm2 = self.norm2(x)
        # Modulate: scale, shift
        x_modulated2 = x_norm2 * (1 + scale_mlp) + shift_mlp
        # Apply MLP
        mlp_output = self.mlp(x_modulated2)
        # Apply gating and add the residual connection
        x = x + gate_mlp * mlp_output

        # Output shape: [batch_size, num_patches, hidden_dim]
        return x
```

## Code Walkthrough

**Initialization (`__init__`):**

- Two `LayerNorm` modules (`norm1`, `norm2`) are created with `elementwise_affine=False`. This is important because the adaptive scale and shift are provided by our modulation network, not learned statically within the LayerNorm layer itself.
- The `AdaLNModulation` instance (`modulation`) computes the six necessary parameters (shift, scale, and gate for both the MHSA and MLP paths) from the conditioning vector `c`.
- A standard `MultiheadAttention` layer (`attn`) is defined. `batch_first=True` makes it easier to work with inputs of shape `[batch, sequence, feature]`.
- A standard MLP block (`mlp`) is created, expanding the hidden dimension by a factor (`mlp_ratio`) and then projecting back. GELU is a common activation function here.

**Forward Pass (`forward`):**

- The conditioning vector `c` is passed through the modulation network to get the six parameter sets. We `unsqueeze(1)` to make their shape `[batch_size, 1, hidden_dim]` so they broadcast correctly across the sequence length (number of patches) dimension during modulation.
- First sub-layer (attention):
  - The input `x` is normalized using `norm1`.
  - The normalized `x_norm1` is modulated using the `scale_msa` and `shift_msa` parameters: `x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa`. Note the `(1 + scale)` formulation, common in adaptive normalization schemes; because the modulation layer is zero-initialized, `scale` starts at 0, so the effective scale starts at 1.
  - The modulated tensor `x_modulated1` is fed into the self-attention layer (`attn`).
  - The attention output is gated by `gate_msa` and added back to the original input `x` (residual connection): `x = x + gate_msa * attn_output`.
- Second sub-layer (MLP):
  - The output from the attention sub-layer (`x`) is normalized using `norm2`.
  - This normalized tensor `x_norm2` is modulated using `scale_mlp` and `shift_mlp`.
  - The modulated tensor `x_modulated2` is passed through the `mlp`.
  - The MLP output is gated by `gate_mlp` and added back to the input of this sub-layer (residual connection): `x = x + gate_mlp * mlp_output`.
- The final tensor `x` is returned.
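As a small follow-up (a sketch, assuming the `DiTBlock` class above is defined), we can confirm the practical consequence of the zero-initialized gates: a freshly constructed block passes its input through unchanged.

```python
# Illustrative check: with gate_msa and gate_mlp both zero at initialization,
# the gated attention and MLP branches contribute nothing, so output == input.
block = DiTBlock(hidden_dim=768, embedding_dim=256, num_heads=12)
x = torch.randn(2, 196, 768)
c = torch.randn(2, 256)
with torch.no_grad():
    out = block(x, c)
print(torch.allclose(out, x))  # True at initialization
```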
## Using the Block

To build a full DiT, you would stack multiple instances of this `DiTBlock`. The input `x` would initially be the patch embeddings plus positional embeddings, and `c` would be derived from the diffusion timestep (and potentially class labels or other conditioning). The output of the final block would then typically be projected to predict the noise (or the original data `x_0`, depending on the parameterization).

```python
# Example usage (illustrative)
batch_size = 4
num_patches = 196    # e.g., a 224x224 image with 16x16 patches
hidden_dim = 768
embedding_dim = 256  # Dimension of the timestep/class embedding
num_heads = 12

# Sample input tensors
x_patches = torch.randn(batch_size, num_patches, hidden_dim)
conditioning_vec = torch.randn(batch_size, embedding_dim)

# Instantiate the block
dit_block = DiTBlock(hidden_dim, embedding_dim, num_heads)

# Pass inputs through the block
output_tokens = dit_block(x_patches, conditioning_vec)

print(f"Input shape: {x_patches.shape}")
print(f"Conditioning shape: {conditioning_vec.shape}")
print(f"Output shape: {output_tokens.shape}")

# Expected output:
# Input shape: torch.Size([4, 196, 768])
# Conditioning shape: torch.Size([4, 256])
# Output shape: torch.Size([4, 196, 768])
```

This hands-on implementation provides a concrete view of how transformer blocks are adapted within the DiT framework, particularly the adaLN-Zero mechanism for integrating conditioning information effectively. Building on this single block, one can construct the entire transformer backbone for a diffusion model.
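To make the "stack the blocks" idea concrete, here is a minimal, illustrative sketch of such a backbone, not a reference DiT implementation. The `TimestepEmbedder` and `SimpleDiT` names, the sinusoidal embedding helper, and the plain LayerNorm-plus-Linear output head are assumptions made for this example; the reference DiT also conditions its final layer with adaLN-style modulation.

```python
import math

import torch
import torch.nn as nn

# Assumes the DiTBlock class defined above is in scope.


class TimestepEmbedder(nn.Module):
    """Hypothetical helper: maps diffusion timesteps to conditioning vectors
    via a sinusoidal embedding followed by a small MLP."""
    def __init__(self, embedding_dim: int, freq_dim: int = 256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(
            nn.Linear(freq_dim, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, t):
        # t shape: [batch_size] (integer or float timesteps)
        half = self.freq_dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
        )
        args = t.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        return self.mlp(emb)  # [batch_size, embedding_dim]


class SimpleDiT(nn.Module):
    """Hypothetical sketch: stacked DiT blocks plus a simple output projection."""
    def __init__(self, hidden_dim: int, embedding_dim: int, num_heads: int,
                 depth: int, out_dim: int):
        super().__init__()
        self.t_embedder = TimestepEmbedder(embedding_dim)
        self.blocks = nn.ModuleList(
            [DiTBlock(hidden_dim, embedding_dim, num_heads) for _ in range(depth)]
        )
        self.final_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.final_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x, t):
        # x: [batch_size, num_patches, hidden_dim] patch + positional embeddings
        # t: [batch_size] diffusion timesteps
        c = self.t_embedder(t)
        for block in self.blocks:
            x = block(x, c)
        # Project each token back to patch space (e.g., to predict noise per patch)
        return self.final_proj(self.final_norm(x))


# Illustrative usage: predict a flattened 16x16x4 noise patch per token
model = SimpleDiT(hidden_dim=768, embedding_dim=256, num_heads=12,
                  depth=4, out_dim=16 * 16 * 4)
noise_pred = model(torch.randn(2, 196, 768), torch.randint(0, 1000, (2,)))
print(noise_pred.shape)  # torch.Size([2, 196, 1024])
```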