趋近智
条件生成对抗网络(cGAN)允许通过提供额外信息来引导生成器的输出,这些信息通常是类别标签或其他属性,表示为 y。构建和训练cGAN涉及主要实用步骤,着重讲解如何将此条件输入 y 整合到生成器和判别器网络中。
我们将使用熟悉的MNIST数据集作为示例。该数据集包含手写数字(0-9)的灰度图像,因此数字标签很自然地成为我们的条件 y。我们的目标是训练一个生成器,使其在接收到相应标签时,能够生成特定数字的图像。
首先,使用您偏好的深度学习框架工具加载数据集(例如MNIST)。与标准GAN训练只需图像 x 不同,对于cGAN,我们还需要对应的标签 y。请确保您的数据加载器提供 (x,y) 对。
标签 y 通常是整数(MNIST为0到9)。由于神经网络最适合处理数值向量,我们需要将这些整数标签转换为合适的格式。一种常用且有效的方法是使用嵌入层。我们可以将每个标签表示为一个可学习的向量。此外,对于像MNIST这样的离散标签,独热编码是一个直接的选项,尽管嵌入通常提供更大的灵活性和潜在更好的表现,尤其是在类别数量较多时。
我们假设使用嵌入。如果我们有 Nc 个类别,我们可以创建一个嵌入层,将每个整数标签 i∈{0,1,...,Nc−1} 映射到所选维度(例如 de)的密集向量。
生成器 G 现在必须接受两个输入:随机噪声向量 z 和条件信息 y。主要想法是有效地结合这些输入,以便生成器学习利用 y 来形成其输出。
以下是一个结构示例(PyTorch风格):
# 生成器结构
class ConditionalGenerator(nn.Module):
def __init__(self, noise_dim, num_classes, embedding_dim, output_channels):
super().__init__()
self.label_embedding = nn.Embedding(num_classes, embedding_dim)
# 定义生成器网络主体
# 第一层的输入维度应适应 noise_dim + embedding_dim
self.main = nn.Sequential(
# 示例:转置卷积层、批归一化、ReLU
# nn.ConvTranspose2d(noise_dim + embedding_dim, ...)
# ... 其他层 ...
# nn.ConvTranspose2d(..., output_channels, ..., bias=False),
# nn.Tanh() # 输出激活通常使用Tanh,用于将图像缩放到[-1, 1]
)
def forward(self, noise, labels):
# 嵌入标签
label_embedding_vector = self.label_embedding(labels) # 形状: (batch_size, embedding_dim)
# 如果需要,重塑嵌入并与噪声拼接
# 假设噪声形状为 (batch_size, noise_dim, 1, 1) 以用于 ConvTranspose2d
# 我们需要重塑 label_embedding_vector 以在空间上匹配
label_embedding_reshaped = label_embedding_vector.view(label_embedding_vector.size(0), label_embedding_vector.size(1), 1, 1)
# 沿通道维度拼接
combined_input = torch.cat([noise, label_embedding_reshaped], dim=1) # 形状: (batch_size, noise_dim + embedding_dim, 1, 1)
# 生成图像
generated_image = self.main(combined_input)
return generated_image
类似地,判别器 D 现在不仅要评估图像 x,还要评估对 (x,y)。它需要判断图像 x 是否是与标签 y 对应的真实图像,或是为标签 y 生成的假图像。
以下是使用后期拼接的结构示例(PyTorch风格):
# 判别器结构
class ConditionalDiscriminator(nn.Module):
def __init__(self, num_classes, embedding_dim, input_channels):
super().__init__()
self.label_embedding = nn.Embedding(num_classes, embedding_dim)
# 定义图像处理部分(例如,卷积层)
self.image_processor = nn.Sequential(
# 示例:卷积层、批归一化、LeakyReLU
# nn.Conv2d(input_channels, ...)
# ... 其他卷积层 ...
)
# 定义最终分类器部分
# 输入维度需要适应展平的图像特征 + embedding_dim
self.classifier = nn.Sequential(
# 示例:展平、线性层、LeakyReLU
# nn.Flatten(),
# nn.Linear(feature_dim + embedding_dim, ...)
# nn.LeakyReLU(0.2, inplace=True),
# nn.Linear(..., 1) # 输出层(如果使用 BCEWithLogitsLoss 或 Wasserstein 损失,则不使用 sigmoid)
)
# 根据 image_processor 输出形状计算 feature_dim
def forward(self, image, labels):
# 处理图像
image_features = self.image_processor(image) # 形状取决于层
image_features_flat = image_features.view(image_features.size(0), -1) # 展平特征
# 嵌入标签
label_embedding_vector = self.label_embedding(labels) # 形状: (batch_size, embedding_dim)
# 拼接展平特征和标签嵌入
combined_input = torch.cat([image_features_flat, label_embedding_vector], dim=1)
# 分类
validity = self.classifier(combined_input)
return validity
下图显示了cGAN中的数据流,突出显示了条件标签 y 被整合的位置。
条件生成对抗网络中的数据流。条件 y(黄色)在生成器中被嵌入并与噪声 z(蓝色)组合,在判别器中与图像特征(绿色)组合。判别器根据图像及其假定条件输出一个判断(红色)。
目标函数仍然是最小最大博弈,但现在 D 和 G 也依赖于 y。价值函数 V(D,G) 为:
GminDmaxV(D,G)=E(x,y)∼pdata(x,y)[logD(x,y)]+Ez∼pz(z),y∼py(y)[log(1−D(G(z,y),y))]此处,pdata(x,y) 是真实数据和标签的联合分布,py(y) 是标签的分布(我们通常均匀采样或根据训练集分布进行采样)。
实际操作中,当使用标准二元交叉熵损失(通常用 BCEWithLogitsLoss 实现以增加稳定性)时,判别器会尝试为真实对 (x,y) 输出高值,为虚假对 (G(z,y),y) 输出低值。生成器通过使 D(G(z,y),y) 输出高值来试图欺骗判别器。请记住,在生成 G(z,y) 并将其传递给判别器时,要使用相同的标签 y。
cGAN训练循环遵循标准GAN模式,其中特别增加了对标签 y 的处理:
更新判别器:
.detach()。更新生成器:
.detach()。重复这些步骤达到所需的训练轮次。请记住标准的GAN训练方法,例如使用合适的优化器(例如Adam)、学习率,以及如果需要,使用第三章中讨论的稳定技术。
训练完成后,您可以生成以特定标签为条件的图像。只需:
例如,要只生成数字“7”的图像,您将使用不同的噪声向量 z 反复调用 G(z,标签=7)。
衡量cGAN不仅包括评估生成图像的质量和多样性(使用第五章讨论的FID或IS等指标),还包括条件一致性。生成器是否生成了实际匹配所请求标签 y 的图像?这可以通过目视检查进行定性确认,也可以通过将生成的图像 G(z,y) 输入到预训练分类器(独立于cGAN的判别器)并测量其预测 y 的准确性来进行定量检查。
本次实操练习为实现cGAN提供了蓝图。通过将条件信息细致地整合到生成器和判别器中,您将对生成过程获得很大控制,使得能够基于特定属性进行目标合成。尝试不同的嵌入维度和拼接策略,观察它们如何影响您所选数据集上的表现。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造