具体的代码片段将展示使用PyTorch实现U-Net架构和纳入时间步信息。示例说明了如何实现主要组成部分:U-Net结构本身和正弦时间步编码。实现正弦时间步编码如前所述,网络需要知道它在扩散过程中的哪一个点进行操作。正弦编码提供了一种方法,可以将时间步 $t$ 编码成一个固定大小的向量,网络可以方便地使用它。这些编码的公式涉及正弦和余弦函数,它们应用于时间步,并在不同频率上进行缩放:$$ PE(t, 2i) = \sin(t / 10000^{2i/d}) $$ $$ PE(t, 2i+1) = \cos(t / 10000^{2i/d}) $$这里,$t$ 是时间步,$i$ 是编码维度索引,$d$ 是总编码维度。该方法使模型能够有效地区分不同的时间步。让我们在PyTorch中实现一个函数来生成这些编码:import torch import torch.nn as nn import math class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim # 编码的维度 def forward(self, time): """ 为一批时间步生成正弦编码。 Args: time (torch.Tensor): 形状为 (batch_size,) 的张量,包含时间步。 Returns: torch.Tensor: 形状为 (batch_size, dim) 的张量,包含编码。 """ device = time.device half_dim = self.dim // 2 # 计算频率(10000 的指数) embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) # 计算正弦和余弦的参数 embeddings = time[:, None] * embeddings[None, :] # 连接正弦和余弦分量 embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) # 如有必要处理奇数维度(尽管 'dim' 通常是偶数) if self.dim % 2 == 1: embeddings = torch.cat([embeddings, torch.zeros_like(embeddings[:, :1])], dim=-1) return embeddings # 示例用法: time_embedding_dim = 128 # 选择一个编码维度 time_embed_module = SinusoidalPosEmb(time_embedding_dim) # 模拟一批时间步 batch_size = 16 timesteps = torch.randint(0, 1000, (batch_size,)) # 示例:1000个扩散步骤 # 生成编码 time_embeddings = time_embed_module(timesteps) print(f"时间步形状: {timesteps.shape}") print(f"生成的编码形状: {time_embeddings.shape}") # 预期输出: # 时间步形状: torch.Size([16]) # 生成的编码形状: torch.Size([16, 128])这个 SinusoidalPosEmb 模块接收一批时间步并生成相应的编码。这些编码随后将被投影并添加到U-Net层中。U-Net基本模块结构U-Net依靠可重用模块进行下采样和上采样。一个典型模块包含卷积、归一化(例如组归一化,它在批大小变化时表现良好)、激活函数(例如SiLU或GeLU)以及可能的残差连接。时间步编码通常被纳入这些模块中。这是一个包含时间步编码集成的残差模块的简化示例:class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, time_emb_dim, groups=8): super().__init__() # 将时间步编码投影以匹配通道维度 self.time_mlp = nn.Sequential( nn.SiLU(), # Swish 激活 nn.Linear(time_emb_dim, out_channels) ) # 第一个卷积层 self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) self.norm1 = nn.GroupNorm(groups, out_channels) self.act1 = nn.SiLU() # 第二个卷积层 self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.norm2 = nn.GroupNorm(groups, out_channels) self.act2 = nn.SiLU() # 残差连接处理(如果输入/输出通道不同) if in_channels != out_channels: self.res_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) else: self.res_conv = nn.Identity() # 无需更改 def forward(self, x, t_emb): """ Args: x (torch.Tensor): 输入特征图 (批次, 输入通道, 高度, 宽度)。 t_emb (torch.Tensor): 时间步编码 (批次, 时间编码维度)。 Returns: torch.Tensor: 输出特征图 (批次, 输出通道, 高度, 宽度)。 """ h = self.conv1(x) h = self.norm1(h) h = self.act1(h) # 纳入时间步编码 time_bias = self.time_mlp(t_emb) # 将 time_bias 重塑为 (批次, 输出通道, 1, 1) 以进行广播 h = h + time_bias[:, :, None, None] h = self.conv2(h) h = self.norm2(h) h = self.act2(h) # 添加残差连接 return h + self.res_conv(x) # 示例用法(需要示例输入和时间编码) in_channels = 32 out_channels = 64 time_emb_dim = 128 img_size = 32 # 示例图像维度 # 创建虚拟输入张量和时间编码 x_sample = torch.randn(batch_size, in_channels, img_size, img_size) t_emb_sample = torch.randn(batch_size, time_emb_dim) # 通常使用预生成的编码 # 实例化并运行模块 block = ResidualBlock(in_channels, out_channels, time_emb_dim) output = block(x_sample, t_emb_sample) print(f"输入形状: {x_sample.shape}") print(f"时间编码形状: {t_emb_sample.shape}") print(f"输出形状: {output.shape}") # 预期输出: # 输入形状: torch.Size([16, 32, 32, 32]) # 时间编码形状: torch.Size([16, 128]) # 输出形状: torch.Size([16, 64, 32, 32])在这个 ResidualBlock 中,时间步编码 t_emb 首先通过一个小型多层感知机 (time_mlp),它由一个激活函数 (SiLU) 和一个线性层组成。这会将编码投影到 out_channels 维度。结果在第一次卷积和归一化后作为偏差添加到特征图 h 中。[:, :, None, None] 部分重塑偏差,使其能够在空间维度(高度和宽度)上进行广播。该模块以第二次卷积、归一化、激活和残差连接 (res_conv(x)) 的添加结束。构成U-Net结构一个完整的U-Net使用这些模块,以及下采样(例如 nn.MaxPool2d 或步幅卷积)和上采样(例如 nn.Upsample 或 nn.ConvTranspose2d)层。跳跃连接将下采样路径中的特征图连接到上采样路径中对应的层。下面是这些组件如何组合的概略示意图。请注意,这是一个简化表示;实际的U-Net通常在每个分辨率级别有多个模块,并可能包含注意力机制。digraph UNet { rankdir=LR; node [shape=box, style="filled", fillcolor="#e9ecef", fontname="Arial"]; edge [fontname="Arial"]; subgraph cluster_input { label = "输入"; style=filled; color="#dee2e6"; Input [label="输入图像\n(批次, C, H, W)", fillcolor="#a5d8ff"]; TimeInput [label="时间步 t\n(批次,)", fillcolor="#ffd8a8"]; } subgraph cluster_time_embed { label = "时间步编码"; style=filled; color="#dee2e6"; TimeEmbed [label="正弦编码", fillcolor="#ffe066"]; TimeMLP [label="时间 MLP", fillcolor="#ffe066"]; TimeInput -> TimeEmbed -> TimeMLP; EmbeddedTime [label="编码后的 t\n(批次, 编码维度)", shape=ellipse, fillcolor="#ffe066"]; TimeMLP -> EmbeddedTime [style=dashed]; } subgraph cluster_down { label = "编码器 (下采样路径)"; style=filled; color="#dee2e6"; node [fillcolor="#bac8ff"]; edge []; EncBlock1 [label="残差模块 1"]; Down1 [label="下采样"]; EncBlock2 [label="残差模块 2"]; Down2 [label="下采样"]; } subgraph cluster_bottleneck { label = "瓶颈"; style=filled; color="#dee2e6"; node [fillcolor="#d0bfff"]; BNBlock [label="残差模块"]; } subgraph cluster_up { label = "解码器 (上采样路径)"; style=filled; color="#dee2e6"; node [fillcolor="#96f2d7"]; edge []; Up1 [label="上采样"]; DecBlock1 [label="残差模块 1\n+ 跳跃连接"]; Up2 [label="上采样"]; DecBlock2 [label="残差模块 2\n+ 跳跃连接"]; } subgraph cluster_output { label = "输出"; style=filled; color="#dee2e6"; FinalConv [label="最终卷积", fillcolor="#ffc9c9"]; Output [label="预测噪声\n(批次, C, H, W)", fillcolor="#ffa8a8"]; FinalConv -> Output; } # Connections Input -> EncBlock1; EncBlock1 -> Down1; Down1 -> EncBlock2; EncBlock2 -> Down2; Down2 -> BNBlock; BNBlock -> Up1; Up1 -> DecBlock1; DecBlock1 -> Up2; Up2 -> DecBlock2; DecBlock2 -> FinalConv; # Skip Connections EncBlock1 -> DecBlock2 [label="跳跃", style=dashed, constraint=false, color="#868e96"]; EncBlock2 -> DecBlock1 [label="跳跃", style=dashed, constraint=false, color="#868e96"]; # Time Embedding Injection EmbeddedTime -> EncBlock1 [style=dotted, color="#f76707", constraint=false]; EmbeddedTime -> EncBlock2 [style=dotted, color="#f76707", constraint=false]; EmbeddedTime -> BNBlock [style=dotted, color="#f76707", constraint=false]; EmbeddedTime -> DecBlock1 [style=dotted, color="#f76707", constraint=false]; EmbeddedTime -> DecBlock2 [style=dotted, color="#f76707", constraint=false]; }图解了用于噪声预测的U-Net架构内的流程。输入图像和时间步编码通过下采样模块(编码器)、一个瓶颈和带跳跃连接的上采样模块(解码器)进行处理。时间步信息通常会被注入到多个模块中。构建一个完整的U-Net涉及堆叠这些 ResidualBlock 组件,添加适当的下采样(例如使用 stride=2 的 nn.Conv2d)和上采样层(例如 nn.Upsample 或 nn.ConvTranspose2d),并实现跳跃连接(通常是沿通道维度简单的张量连接)。主U-Net类的 forward 方法组织此流程,将输入图像 x 和生成的时间步编码 t_emb 传递通过网络结构。"请记住,这是一个基础设置。实现中通常会纳入更复杂的元素,例如注意力层(特别是在较低分辨率下),以及对通道数量和模块深度进行细致的调整,这些对于获得高质量的生成结果非常重要。然而,理解这些基本构成模块为处理扩散模型架构提供了一个扎实的起点。"