逆转生成式AI中的扩散过程,需要一个能够准确估计添加到图像$x_0$中以在特定时间步$t$生成带噪声版本$x_t$的噪声的模型。当给定$x_t$和$t$时,模型的任务是预测为创建$x_t$而从高斯分布中采样的噪声$\epsilon$。为此预测任务选择的标准神经网络架构,特别是在图像生成方面,是U-Net。U-Net架构最初为生物医学图像分割而开发,已证明对扩散模型非常有效。其设计非常适合输入和输出具有相同空间维度(如图像及其对应的噪声图)且同时保持精细空间细节并考虑较广背景的任务。U-Net因其可视化时的U形特征而得名。它由三个主要部分组成:下采样路径(编码器)瓶颈上采样路径(解码器)重要的一点是,它还采用跳跃连接,在下采样和上采样路径之间连接对应的层。下面我们查看每个部分:下采样路径(编码器)编码器路径的功能类似于用于分类或特征提取的典型卷积神经网络。它接收输入(带噪声图像$x_t$以及时间步信息$t$,我们将在下一节讨论如何整合这些信息),并通过一系列层对其进行处理。每个层通常包含:卷积层: 这些层应用学习到的滤波器从输入中提取特征。通常每个层使用两个或更多卷积层。激活函数(如ReLU或GeLU)在卷积后应用。下采样操作: 在特定空间分辨率下进行特征提取后,下采样层(例如,最大池化或步幅卷积)会减小特征图的高度和宽度(通常是2倍),同时通常会增加特征通道的数量(深度)。编码器的目的是逐渐降低空间分辨率,同时增加学习到特征的语义复杂度。通过下采样,网络在更深层获得更大的感受野,使其能够捕获输入图像较广区域的背景信息。这对于理解整体结构和内容是必要的,有助于预测合适的噪声模式。瓶颈这是“U”形中的最低点,连接编码器和解码器路径。它通常包含一个或多个卷积层。瓶颈以高度压缩、低空间分辨率、高层特征表示形式表示输入图像。它捕获编码器学习到的最显著、最抽象的信息。上采样路径(解码器)解码器路径逐步增加特征图的空间分辨率,使其恢复到原始输入大小,最终生成预测的噪声图$\epsilon_\theta$。解码器中的每个层通常包括:上采样操作: 这会增加特征图的高度和宽度(例如,使用转置卷积,有时称为“反卷积”,或更简单的方法,如最近邻上采样后跟标准卷积)。通过跳跃连接进行拼接: 这是一个独特之处。上采样的特征图沿通道维度与编码器路径中对应的特征图(在相同空间分辨率下)进行拼接。我们将在下面讨论其重要性。卷积层: 类似于编码器,卷积层(带激活)用于处理拼接后的特征图,优化此分辨率下的表示。解码器本质上通过逐步组合从瓶颈传递上来的高层信息与编码器跳跃连接提供的细粒度、高分辨率特征,来重建详细的噪声图。跳跃连接跳跃连接是直接的链接,将特征图从下采样路径(编码器)中的层传递到上采样路径(解码器)中对应的层。“对应”通常指具有相同空间分辨率的层。为什么这些对于扩散模型中的噪声预测如此重要?保持空间细节: 下采样过程本身会损失细粒度空间信息。通过将高分辨率特征图直接从编码器传递到解码器,跳跃连接使网络能够重用输入$x_t$中精确的空间细节。这对于预测准确对应噪声图像中存在结构噪声极为重要,从而在去噪后生成更清晰、更详细的图像。梯度流动: 它们有助于缓解训练期间的梯度消失问题,使梯度能更直接地流回网络的早期层。如果没有跳跃连接,解码器将只能接收来自高度压缩的瓶颈表示的信息,这将使重建一个精确的、像素级的、符合原始图像结构的噪声图变得非常困难。digraph G { rankdir="TB"; splines=ortho; // Use orthogonal lines for cleaner look node [shape=box, style=filled, fontname="Helvetica", fontsize=10, width=2, height=0.6]; edge [fontname="Helvetica", fontsize=9]; subgraph cluster_main { label = "用于噪声预测的U-Net架构"; bgcolor="#f8f9fa"; // Light background for the whole thing graph[style=dotted]; // Add border to main cluster // 输入 input_node [label="输入\n(带噪声数据 x_t, 时间步 t)", shape=ellipse, fillcolor="#adb5bd"]; // 编码器路径 subgraph cluster_encoder { label = "编码器路径"; bgcolor="#e9ecef"; node [fillcolor="#a5d8ff"]; // Blueish nodes graph[style=dashed]; enc1 [label="第1层\n卷积层"]; enc_pool1 [label="下采样", shape=invtrapezium, fillcolor="#74c0fc"]; enc2 [label="第2层\n卷积层"]; enc_pool2 [label="下采样", shape=invtrapezium, fillcolor="#74c0fc"]; enc3 [label="第3层\n卷积层"]; enc_pool3 [label="下采样", shape=invtrapezium, fillcolor="#74c0fc"]; } // 瓶颈 subgraph cluster_bottleneck { label = "瓶颈"; bgcolor="#e9ecef"; node [fillcolor="#b2f2bb"]; // Greenish node graph[style=dashed]; bn [label="瓶颈\n卷积层"]; } // 解码器路径 subgraph cluster_decoder { label = "解码器路径"; bgcolor="#e9ecef"; node [fillcolor="#ffd8a8"]; // Orange nodes graph[style=dashed]; dec_up3 [label="上采样", shape=trapezium, fillcolor="#ffc078"]; dec3 [label="第3层\n拼接 + 卷积层"]; dec_up2 [label="上采样", shape=trapezium, fillcolor="#ffc078"]; dec2 [label="第2层\n拼接 + 卷积层"]; dec_up1 [label="上采样", shape=trapezium, fillcolor="#ffc078"]; dec1 [label="第1层\n拼接 + 卷积层"]; } // 输出 output_conv [label="最终卷积层\n(例如, 1x1 卷积)", fillcolor="#adb5bd"]; output_node [label="输出\n(预测噪声 \u03b5_\u03b8)", shape=ellipse, fillcolor="#adb5bd"]; // 连接 input_node -> enc1 [color="#495057"]; // 编码器路径流 enc1 -> enc_pool1 [color="#495057"]; enc_pool1 -> enc2 [color="#495057"]; enc2 -> enc_pool2 [color="#495057"]; enc_pool2 -> enc3 [color="#495057"]; enc3 -> enc_pool3 [color="#495057"]; enc_pool3 -> bn [color="#495057"]; // 解码器路径流 bn -> dec_up3 [color="#495057"]; dec_up3 -> dec3 [color="#495057"]; dec3 -> dec_up2 [color="#495057"]; dec_up2 -> dec2 [color="#495057"]; dec2 -> dec_up1 [color="#495057"]; dec_up1 -> dec1 [color="#495057"]; // 跳跃连接(虚线,不同颜色) edge [style=dashed, color="#7950f2", constraint=false]; // Violet color for skips enc3 -> dec3 [label=" 跳跃连接"]; enc2 -> dec2 [label=" 跳跃连接"]; enc1 -> dec1 [label=" 跳跃连接"]; // 最终输出连接 edge [style=solid, color="#495057", constraint=true]; // Reset edge style dec1 -> output_conv; output_conv -> output_node; } }图示U-Net结构。箭头表示数据流向。编码器路径逐步减小空间维度,而解码器路径则增加空间维度。跳跃连接(紫色虚线)将高分辨率特征从编码器传递到解码器。总而言之,U-Net架构有效地结合了对背景信息的理解(通过编码器和瓶颈)和精确的空间定位(由解码器和跳跃连接实现)。这使其非常适合预测一个与输入带噪声图像$x_t$具有相同维度,并准确反映对应图像内容和时间步$t$所指示噪声水平的噪声图$\epsilon_\theta$的任务。U-Net的最后一层通常是一个卷积(例如1x1或3x3),它将特征表示映射到所需的输出通道数量(例如,RGB图像噪声的3个通道)。