趋近智
让我们将条件生成理论付诸实践。在前面讨论过的架构和训练策略的基础上,本节提供实践指导,关于如何实现一个根据特定条件(例如类别标签)生成图像的系统。这项功能对于需要可控合成的任务(例如生成特定类型的物体或以目标方式扩充数据集)来说是核心。
实现一个类别条件生成对抗网络(cGAN),这是一种常见且有效的方法。主要思想简单而有效:将条件信息(类别标签)作为额外输入提供给生成器和判别器。这会迫使生成器生成与标签相关的图像,并训练判别器验证图像是否真实 且 与其给定标签匹配。
假设你有一个在无标签数据集上训练的标准GAN实现(例如使用PyTorch)。为了使其成为条件模型,我们需要修改数据管道、生成器和判别器。我们将使用像CIFAR-10这样的数据集,它包含图像和相应的类别标签。
首先,确保你的数据加载器提供图像及其整数类别标签。由于神经网络最适合处理连续向量,我们需要将这些整数标签转换为嵌入。嵌入层适用于此。
# 示例参数
num_classes = 10 # 对于CIFAR-10
embedding_dim = 16 # 标签嵌入向量的大小
# 嵌入层
label_embedding = nn.Embedding(num_classes, embedding_dim)
# 在训练循环中:
# real_images, labels = next(data_loader_iter)
# 将标签转换为嵌入
label_input = label_embedding(labels) # 形状: (batch_size, embedding_dim)
embedding_dim 是一个可以调整的超参数。
生成器需要接收随机噪声向量 z 和条件(标签嵌入)。一种常见策略是将它们拼接起来。
# 示例生成器输入
latent_dim = 100
# ... (如上定义 label_embedding)
# 在训练循环中:
noise = torch.randn(batch_size, latent_dim, device=device)
# 生成用于合成的标签(例如,随机标签或特定标签)
gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)
gen_label_input = label_embedding(gen_labels)
# 拼接噪声和标签嵌入
generator_input = torch.cat((noise, gen_label_input), dim=1)
# 通过生成器
fake_images = generator(generator_input)
生成器的第一层现在必须接受 latent_dim + embedding_dim 的输入大小。另外,标签嵌入可以在中间层进行投影并添加或拼接,类似于StyleGAN或注意力机制中使用的技术,从而实现更多控制。
判别器也必须接收条件信息。它需要判断图像是否真实 且 属于其假定的类别。一种常见做法是将标签嵌入与图像数据一同输入。
一种常见技术是嵌入标签,对其进行空间重塑,并将其作为额外通道拼接到图像张量。另一种做法是,先通过初始卷积层处理图像,然后将展平的图像特征与标签嵌入拼接,再将其输入到后面的全连接层。
我们用后一种方法进行说明:
# 假设 'discriminator_features' 从图像中提取特征
# 假设 'label_embedding' 生成标签嵌入
# 在判别器的正向传播中:
image_features = self.feature_extractor(image) # 例如,卷积层的输出
image_features_flat = image_features.view(image.size(0), -1)
# label_input 是传递给判别器的嵌入标签
discriminator_input = torch.cat((image_features_flat, label_input), dim=1)
validity = self.final_layers(discriminator_input) # 最终分类层
判别器在拼接点之后的层需要适应图像特征和标签嵌入的总大小。
训练循环需要小心处理真实和虚假样本的标签:
这是一个简化的图表,说明了cGAN中的信息流:
类别条件GAN中的信息流。条件(标签
y)被嵌入并作为输入提供给生成器(连同噪声z)和判别器(连同真实或生成的图像)。
成功训练后,你应该能够向生成器提供特定的类别标签(例如,用于数字生成的“3”,或用于CIFAR-10的“猫”)以及一个噪声向量,并获得代表该类别的图像。目视检查是第一步:生成根据每个类别条件的图像批次,并检查它们是否合适。
对于定量评估(第5章已讨论),你可以计算 FID 或 IS 等指标,按类别 评估每个类别内的质量和多样性。比较特定类别的生成样本分布与该同一类别的真实样本分布也很有益。
尽管我们侧重于cGAN,但条件生成在扩散模型中也高效。分类器引导是指使用一个单独的预训练分类器在每个去噪步骤中引导采样过程趋向所需类别。一种更现代且通常更受欢迎的技术是无分类器引导(第4章详细介绍)。这包括在条件和无条件输入上交替训练扩散模型。在采样期间,噪声预测基于条件和无条件得分进行外推,从而在不需要外部分类器的情况下实现强引导。实现无分类器引导需要稍微修改U-Net架构以接受条件嵌入(如类别标签或文本嵌入),并调整训练过程以在某些训练步骤中随机丢弃条件。
这项实践练习说明了如何扩展生成模型以实现可控合成。通过加入条件信息,你对输出有了很大的控制,从而能够针对本章中讨论的各种高级应用进行定向数据生成,从特定对象合成到引导式数据扩充。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造