实现 Diffusion Transformer (DiT) 需要留意多个实用细节。尽管用 Transformer 块替换 U-Net 骨干网络,在建模长距离依赖性方面可能带来优势,但这也引出计算成本、数据处理和训练稳定性方面的特定难题与考量。构建或训练 DiT 模型时需要处理的重要方面将被检查。
计算成本与可扩展性
Transformer 的主要组成是自注意力机制。标准自注意力机制的计算和内存复杂度为 O(N2),其中 N 是序列长度。在 DiT 处理图像数据的情况下,N 对应于图像被划分的补丁数量。对于分辨率为 H×W、补丁大小为 P×P 的图像,补丁数量为 N=(H×W)/P2。
这种二次方扩展意味着将图像分辨率加倍(像素数量增加四倍)或将补丁大小减半(补丁数量增加四倍),都会使注意力计算的成本增加约 16 倍。这显著影响训练时间和 GPU 内存需求,特别是对于高分辨率图像生成。尽管像 FlashAttention 这样的技术可以优化注意力实现,但它们并未改变基本的二次方复杂度。这与基于 CNN 的 U-Net 形成对比,在 U-Net 中,卷积操作通常与像素数量呈线性关系,即 O(Npixels)。因此,选择补丁大小和管理序列长度是在设计和训练 DiT 时要点考量。
块嵌入策略
Transformer 处理的是令牌序列。要将它们应用于图像,输入图像 xt 在时间步 t 必须转换为这样的序列:
- 图像分块: 图像张量(例如,形状为
[批次, 通道, 高度, 宽度])被划分为不重叠的块网格。对于大小为 256×256、块大小为 16×16 的图像,您将得到 (256/16)×(256/16)=16×16=256 个块。
- 线性投影: 每个块(例如,形状为
[通道, P, P])被展平并线性投影到一个维度为 D(Transformer 的隐藏维度)的嵌入向量。这会产生一个包含 N 个嵌入向量的序列,通常形状为 [批次, N, D]。
- 位置嵌入: 由于自注意力机制是置换不变的,模型需要关于每个块原始空间位置的信息。位置嵌入被添加到块嵌入中。这些可以在训练期间学习,也可以是固定的(例如,二维正弦嵌入)。应用于序列顺序的标准一维位置嵌入在许多实现中都很常见。
块大小 P 的选择影响很大。
- 较小的 P: 导致序列 N 更长,二次方地增加计算成本。但是,它允许模型通过初始投影潜在地捕获每个块内更精细的细节。
- 较大的 P: 减少 N 和计算成本,但在 Transformer 层处理它们之前,可能会将块内的重要局部细节平均化。
Transformer 块设计与条件作用
标准 DiT 架构使用一系列 Transformer 块。每个块通常包含层归一化 (LN)、多头自注意力 (MHSA) 和一个 MLP(通常是两个带有 GeLU 等激活函数的线性层)。DiT 中的一个重要创新是如何使用自适应层归一化,特别是 adaLN-Zero,来融入时间步 t 和条件信息 c(例如类别标签)。
adaLN-Zero 不是简单地将时间步和条件嵌入添加到序列中,而是调制 Transformer 子块的输出。对于隐藏状态 h,adaLN-Zero 操作是:
adaLN-Zero(h,γ,β,α)=α⋅LayerNorm(h)+β
这里,γ(LayerNorm 内部用于缩放),β(偏移),和 α(输出缩放)由一个小型 MLP 动态计算,该 MLP 以时间步 t 和条件 c 的嵌入作为输入。这些嵌入通常首先进行处理:
- 时间步 t 使用正弦特征后接一个 MLP 转换为嵌入。
- 条件 c(例如,一个类别索引)被映射到一个学习到的嵌入向量。
- 这两个嵌入相加:emb=MLP(sinusoidal(t))+Embedding(c)。
- 最终的 MLP 预测参数:(γ,β,α)=MLPadaLN(emb)。
这些自适应参数应用于特定点,通常在每个 Transformer 块内的 MHSA 和 MLP 层之前,有时也用于调制残差连接。“Zero”部分指的是将生成 α 和 β(并影响 γ)的最终 MLP 层初始化为输出零。这意味着这些自适应层在初始阶段充当恒等函数,有助于训练稳定性,尤其是在早期。
训练稳定性与优化
训练 DiT 等大型 Transformer 模型需要仔细优化:
- 优化器: AdamW 是一个常见选择,通常使用 β1=0.9,β2=0.999。
- 学习率: 适当的学习率调度(例如,带预热的余弦衰减)很重要。峰值学习率可能在 1e−4 到 5e−4 的范围,具体取决于模型大小和批次大小。
- 权重衰减: 应用于稳定训练并改进泛化能力。
- 混合精度训练: 使用 FP16(16 位浮点)或 BF16(脑浮点)对于管理内存使用和加速现代 GPU(如 NVIDIA Tensor Cores)上的训练几乎是不可或缺的。这需要梯度缩放以防止低精度格式的数值下溢/溢出问题。PyTorch 的
torch.cuda.amp 或 TensorFlow 的混合精度 API 大致上能自动处理此问题。
- 梯度裁剪: 有时需要防止梯度爆炸,尤其是在训练早期或批次较大时。通常按全局范数裁剪。
- EMA(指数移动平均): 相比直接使用优化器中的原始权重,维护模型权重的 EMA 通常能使最终评估的模型表现更好。EMA 权重通常在推理/采样期间使用。
模型大小选择与扩展
最初的 DiT 论文显示,这些模型表现出可预测的扩展特性。性能,以 FID (Fréchet Inception Distance) 等指标衡量,通常随模型大小(参数数量、深度、宽度)和计算预算的增加而提高。常见配置包括:
- DiT-S (小型): 块较少,隐藏维度 D 较小。
- DiT-B (基础型): 中等大小。
- DiT-L (大型): 块更多, D 较大。
- DiT-XL (超大型): 更大,需要大量计算。
下表概括了扩展的权衡(数值为示例):
| 模型 |
参数量(百万) |
相对计算量 |
潜在 FID(越低越好) |
| DiT-S |
~30 |
1x |
中等 |
| DiT-B |
~100 |
3-4x |
良好 |
| DiT-L |
~400 |
10-15x |
很好 |
| DiT-XL |
~600+ |
20-25x |
目前最佳 |
这种扩展行为使得研究人员可以估算通过投入更多计算资源可获得的性能提升。
Diffusion Transformer 模型大小、大致参数数量与潜在 FID 分数改进(分数越低表明图像质量越好)之间的关系。计算需求随模型大小显著增长。
实现实用建议
- 借鉴现有实现: 从头开始构建 DiT 是复杂的。仔细查看受认可的开源实现(例如原始 DiT 仓库或
diffusers 等框架中的实现),以理解实用选择。
- 验证形状: 在分块、嵌入、注意力以及反分块阶段,仔细追踪张量形状。形状不匹配是常见的错误源。
- 从小规模开始: 在扩展之前,先用较小的模型(如 DiT-S)和较低分辨率进行实验。这能加快调试周期。
- 监控训练: 使用 TensorBoard 或 Weights & Biases 等工具,在整个训练过程中追踪损失曲线、梯度范数和生成的图像样本。这有助于诊断发散或收敛缓慢等问题。
- 硬件: 对硬件需求要有切实际的认知。有效训练更大的 DiT 通常需要多块高显存 GPU(例如 A100、H100)和分布式训练设置(例如使用 PyTorch 的 DistributedDataParallel)。
通过仔细考量这些计算、架构和优化方面的要点,您可以成功实现和训练 Diffusion Transformer 模型,用于高质量图像生成任务。