趋近智
构建一个单个的扩散Transformer(DiT)模块是一个实践练习,它体现了其核心组成部分。这个实现说明了调节信息(如时间步嵌入)是如何整合的,以及自注意力如何在扩散框架内作用于图像块嵌入,从而阐明了DiTs如何调整Transformer原理用于图像生成。
我们将使用PyTorch来构建这个模块,这假定您对PyTorch的nn.Module和标准Transformer组件(如层归一化、多头自注意力以及MLP层)有实际的操作知识。
回顾我们对DiT架构的讨论,它处理图像块嵌入序列以及调节信息。一个标准DiT模块通常包含:
x(形状 [batch_size, num_patches, hidden_dim])和一个调节向量 c(形状 [batch_size, embedding_dim],通常源自时间步以及可能的类别标签)。c 动态生成的层归一化层。“Zero”变体初始化生成这些参数的投影,使得模块在初始时表现得像一个恒等函数,这有助于提升稳定性。具体的排列方式常遵循Transformer中常见的预归一化风格:
Input -> adaLN -> MHSA -> Residual Add -> adaLN -> MLP -> Residual Add -> Output
DiT的一个重要特点是调节信息 c 如何调制图像令牌 x 的处理过程。这是通过adaLN-Zero机制完成的。我们需要一个小型子网络,它接收调节向量 c 并输出用于偏移、缩放和门控归一化激活的参数。
我们来为之定义一个辅助函数或模块。它接收 c 并产生 shift、scale 和 gate 参数。由于adaLN-Zero在MHSA和MLP层 之前 应用,我们需要为它们分别准备参数集。输出维度应为 6 * hidden_dim(MHSA的偏移、缩放、门控;MLP的偏移、缩放、门控)。
import torch
import torch.nn as nn
class AdaLNModulation(nn.Module):
"""
从调节嵌入计算偏移、缩放和门控参数。
将最终线性层的权重和偏置初始化为零。
"""
def __init__(self, embedding_dim: int, hidden_dim: int):
super().__init__()
self.silu = nn.SiLU()
# 将线性层初始化为零,以实现“Zero”特性
self.linear = nn.Linear(embedding_dim, 6 * hidden_dim, bias=True)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, c):
# c 的形状: [batch_size, embedding_dim]
mod_params = self.linear(self.silu(c))
# mod_params 的形状: [batch_size, 6 * hidden_dim]
# 拆分为6部分,分别用于 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp
# 每个部分的形状为 [batch_size, hidden_dim]
# 添加 unsqueeze(1) 使其能与令牌维度广播:[batch_size, 1, hidden_dim]
return mod_params.chunk(6, dim=1)
现在我们可以组装DiTBlock了。我们将使用标准的PyTorch模块来构建LayerNorm、MultiheadAttention和MLP。
import torch
import torch.nn as nn
# 假定AdaLNModulation类已在上方定义
class DiTBlock(nn.Module):
"""
一个带有adaLN-Zero调制的标准DiT模块。
"""
def __init__(self, hidden_dim: int, embedding_dim: int, num_heads: int, mlp_ratio: float = 4.0):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
# 归一化层(使用LayerNorm,不带元素级仿射,
# 因为adaLN提供了仿射参数)
self.norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
self.norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
# 生成adaLN参数的调制网络
self.modulation = AdaLNModulation(embedding_dim, hidden_dim)
# 注意力层
self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
# MLP层
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, mlp_hidden_dim),
nn.GELU(), # 或 nn.SiLU()
nn.Linear(mlp_hidden_dim, hidden_dim),
)
def forward(self, x, c):
# x 的形状: [batch_size, num_patches, hidden_dim]
# c 的形状: [batch_size, embedding_dim]
# 计算调制参数(MSA和MLP的偏移、缩放、门控)
# 每个参数的形状:在 unsqueeze 后为 [batch_size, 1, hidden_dim]
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
p.unsqueeze(1) for p in self.modulation(c)
]
# --- 在MHSA之前应用adaLN-Zero ---
x_norm1 = self.norm1(x)
# 调制:缩放,偏移
x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa
# 应用注意力
attn_output, _ = self.attn(x_modulated1, x_modulated1, x_modulated1)
# 应用门控并添加残差连接
x = x + gate_msa * attn_output
# --- 在MLP之前应用adaLN-Zero ---
x_norm2 = self.norm2(x)
# 调制:缩放,偏移
x_modulated2 = x_norm2 * (1 + scale_mlp) + shift_mlp
# 应用MLP
mlp_output = self.mlp(x_modulated2)
# 应用门控并添加残差连接
x = x + gate_mlp * mlp_output
# 输出形状: [batch_size, num_patches, hidden_dim]
return x
初始化 (__init__):
LayerNorm模块(norm1、norm2),并将elementwise_affine设置为False。这一点很重要,因为自适应的缩放和偏移参数将由我们的modulation网络提供,而不是在LayerNorm层内部静态学习。AdaLNModulation实例(modulation)是为了从调节向量 c 计算出6个必需的参数(MHSA和MLP路径的偏移、缩放、门控)。MultiheadAttention层(attn)。batch_first=True使得处理形状为[batch, sequence, feature]的输入更加方便。mlp),它通常会将隐藏维度按一个系数(mlp_ratio)扩展,然后投影回去。GELU是这里常用的激活函数。前向传播 (forward):
c 通过modulation网络,得到六组参数。我们对它们进行unsqueeze(1)操作,使其形状变为[batch_size, 1, hidden_dim],以便在调制过程中能正确地在序列长度(图像块数量)维度上进行广播。x通过norm1进行归一化。x_norm1使用scale_msa和shift_msa参数进行调制:x_modulated1 = x_norm1 * (1 + scale_msa) + shift_msa。请注意(1 + scale)的表示形式,这在自适应归一化方案中很常见。x_modulated1被送入自注意力层(attn)。gate_msa进行门控,并加回到原始输入x(残差连接):x = x + gate_msa * attn_output。x)通过norm2进行归一化。x_norm2使用scale_mlp和shift_mlp进行调制。x_modulated2通过mlp。gate_mlp进行门控,并加回到该子层的输入(残差连接):x = x + gate_mlp * mlp_output。x。要构建一个完整的DiT,您将堆叠多个此DiTBlock的实例。输入x最初将是图像块嵌入加上位置嵌入,而c将从扩散时间步(以及可能的类别标签或其他调节信息)获取。最终模块的输出通常会投影来预测噪声(或者原始数据x_0,这取决于参数化方式)。
# 使用示例(说明性)
batch_size = 4
num_patches = 196 # 例如,对于224x224图像和16x16图像块
hidden_dim = 768
embedding_dim = 256 # 时间步/类别嵌入的维度
num_heads = 12
# 输入张量示例
x_patches = torch.randn(batch_size, num_patches, hidden_dim)
conditioning_vec = torch.randn(batch_size, embedding_dim)
# 实例化模块
dit_block = DiTBlock(hidden_dim, embedding_dim, num_heads)
# 将输入传递给模块
output_tokens = dit_block(x_patches, conditioning_vec)
print(f"Input shape: {x_patches.shape}")
print(f"Conditioning shape: {conditioning_vec.shape}")
print(f"Output shape: {output_tokens.shape}")
# 预期输出:
# Input shape: torch.Size([4, 196, 768])
# Conditioning shape: torch.Size([4, 256])
# Output shape: torch.Size([4, 196, 768])
这个动手实现提供了一个具体视角,说明Transformer模块如何在DiT框架内被调整,尤其侧重于adaLN-Zero机制如何有效地整合调节信息。在此单个模块的构建基础上,可以搭建扩散模型的整个Transformer主干网络。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造