扩散变换器 (DiT) 架构代表了设计扩散模型主干网络方面的一个重要转变。它将变换器在图像数据上的成功应用(例如视觉变换器 (ViT))的原理引入模型设计中。由 Peebles 和 Xie (2022) 提出,DiT 用纯变换器架构替代了常用的卷积 U-Net,展现出强大的性能和可扩展性,尤其在图像生成任务中。
其基本想法是,将时间步 t 对含噪图像 xt 进行操作的扩散过程,视为一个适合变换器的序列建模问题。DiT 不依赖卷积层归纳偏置(如局部性和平移等变性),而是运用变换器对图像中远距离依赖性进行建模的能力。
输入处理:为变换器准备图像
类似 ViT,DiT 无法直接处理原始像素网格。输入含噪图像 xt∈RH×W×C 必须首先转换成一系列令牌:
- 图像分块: 输入图像 xt 被划分为不重叠的图像块网格。对于高为 H、宽为 W 的图像,使用 P×P 的图像块尺寸会产生 N=(H×W)/P2 个图像块。
- 线性嵌入: 每个图像块被展平为一个向量,然后线性投影成维度为 D 的令牌嵌入。这就形成了一个由 N 个令牌组成的初始序列:z0∈RN×D。
- 位置嵌入: 由于变换器具有排列不变性,必须显式地添加位置信息。标准的、学习得到的 1D 或 2D 位置嵌入 Epos∈RN×D 会被添加到图像块嵌入中:z0′=z0+Epos。这个序列 z0′ 构成了变换器块的输入。
核心变换器块
DiT 的主体由 L 个变换器块组成。每个块处理令牌序列,细化表示。一个标准 DiT 块通常包括:
- 层归一化 (LN): 应用于注意力层或 MLP 层之前,以稳定激活。
- 多头自注意力 (MHSA): 允许序列中的每个令牌关注所有其他令牌,捕获图像块之间的全局关联。
- 层归一化 (LN): 再次应用于前馈网络之前。
- 前馈网络 (FFN): 通常是一个简单的多层感知机 (MLP),包含两个线性层和一个非线性激活函数(例如 GeLU),独立应用于每个令牌。
MHSA 和 FFN 子层周围都使用了残差连接,保证梯度平滑流动,并使深度变换器的训练成为可能。
zl′′=MHSA(LN(zl−1′))+zl−1′zl′=FFN(LN(zl′′))+zl′′
在此,zl−1′ 是第 l 个块的输入,zl′ 是其输出。
条件化:引入时间步和上下文
扩散模型必须以当前时间步 t 为条件,并且可能以其他上下文信息 c(如类别标签、文本嵌入等)为条件。DiT 需要机制将这些条件信息注入到变换器块中。
- 时间和条件嵌入: 时间步 t 首先被转换成一个向量嵌入 et,通常使用正弦嵌入后接一个 MLP。类似地,上下文 c 被映射到一个嵌入 ec。
- 自适应层归一化 (adaLN / adaLN-Zero): 这是 DiT 中一种常见技术。不同于标准的层归一化,自适应归一化层会根据嵌入 et 和 ec 动态计算尺度 (γ) 和偏移 (β) 参数。这些参数调节每个变换器块内的归一化激活,通常在 MHSA 和 FFN 层之前。
自适应层归一化(h,et,ec)=(γ)⋅层归一化(h)+(β)
其中 γ 和 β 是通过线性层投影 et 和 ec(通常是拼接或求和)得到的。“adaLN-Zero”变体初始化生成 γ 和 β 的投影层,使得初始输出接近于恒等映射,这可以通过保证条件化块在开始时表现得像残差连接来帮助训练稳定性。
- 其他方法: 尽管 adaLN 很常用,但也有其他选择:
- 将 et 和 ec 直接添加到令牌序列中(例如,将它们作为额外的令牌拼接,或添加到每个令牌嵌入中)。
- 如果条件 c 本身是一个序列(例如文本嵌入),则使用交叉注意力机制。
输出处理:预测扩散目标
经过 L 个变换器块处理后,输出令牌序列 zL′ 需要转换回扩散目标所需的格式(通常是预测在步骤 t 添加的噪声 ϵ,或预测原始图像 x0)。
- 最终调节: 条件嵌入 (et,ec) 通常会使用一个最终的 adaLN 层或类似机制,对 zL′ 进行最后一次输出调节。
- 线性投影: 一个最终的线性层将每个输出令牌嵌入投影回展平图像块的维度(例如,P×P×C)。
- 反分块 / 重塑: 输出令牌被重新排列(“反分块”)以重建最终的输出张量,该张量与输入图像具有相同的维度,表示预测的噪声 ϵθ(xt,t,c) 或 x0。
总体架构图
下图阐明了扩散变换器内部信息的高层次流程:
扩散变换器 (DiT) 中的整体数据流。输入图像 xt 被分块并嵌入。变换器块处理这些令牌,并通过 adaLN 由时间 t 和上下文 c 嵌入进行调节。最终令牌被投影并重塑,以生成扩散目标(例如,噪声 ϵθ)。
DiT 为 U-Net 提供了有力的替代方案。通过操作图像块序列,它们可以有效地对全局图像结构进行建模。它们的可扩展性(通过更大模型带来的性能提升得到证实)与大型语言模型中观察到的趋势一致,使其成为扩散生成建模持续发展中的重要架构。后续章节将比较 DiT 和 U-Net,并讨论实际实现细节。