趋近智
Wasserstein GAN (WGAN) 和梯度惩罚 (GP) 技术旨在解决原始 WGAN 权重裁剪的局限性,从而稳定训练。WGAN-GP 的实际实现如下。这种方法被普遍认为是 GAN 训练稳定性和样本质量上的显著改进。
本实践指南假定您对在 PyTorch 或 TensorFlow 中实现基础 GAN 已感到熟悉。我们将侧重于 WGAN-GP 所需的具体修改。
实现 WGAN-GP 主要涉及对判别器网络、损失函数和训练循环的调整。
判别器旨在最大化其对真实样本和生成样本的得分差异,同时包含梯度惩罚。判别器 (D) 的损失函数为:
LD=Ex~∼Pg[D(x~)]−Ex∼Pr[D(x)]+λEx^∼Px^[(∥∇x^D(x^)∥2−1)2]其中:
让我们分解实现过程,特别是梯度惩罚项。
计算梯度惩罚包含以下几个步骤:
采样插值点: 对于批次中的每个真实样本 x 和生成样本 x~,创建一个插值样本 x^。
x^=ϵx+(1−ϵ)x~这里,ϵ 是从 U[0,1] 均匀采样的随机数。这需要对批次进行逐元素操作。
计算插值点的判别器输出: 将这些插值样本 x^ 输入判别器网络,以获取其得分 D(x^)。
计算梯度: 计算判别器输出 D(x^) 相对于插值输入 x^ 的梯度。这需要使用您的深度学习框架的自动微分功能(例如 PyTorch 中的 torch.autograd.grad 或 TensorFlow 中的 tf.GradientTape)。很重要的一点是,要确保为输入计算梯度(PyTorch 中需要 create_graph=True,因为梯度惩罚本身就是损失图的一部分)。
计算梯度范数: 针对每个插值样本,计算这些梯度的 L2 范数(欧几里得范数)。
计算惩罚: 将每个样本的惩罚计算为 (∥∇x^D(x^)∥2−1)2。
平均并缩放: 对批次中的惩罚进行平均,并乘以系数 λ。
以下是梯度惩罚函数的一个 PyTorch 实现片段:
import torch
import torch.autograd as autograd
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
"""计算 WGAN GP 的梯度惩罚损失"""
batch_size = real_samples.size(0)
# 用于真实样本和伪造样本之间插值的随机权重项
alpha = torch.rand(batch_size, 1, 1, 1, device=device) # 假设为 4D 张量 (B, C, H, W)
# 扩展 alpha 以匹配图像维度
alpha = alpha.expand_as(real_samples)
# 获取真实样本和伪造样本之间的随机插值
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
# 获取插值点的判别器得分
d_interpolates = critic(interpolates)
# 使用全一张量作为梯度计算的目标
fake = torch.ones(batch_size, 1, device=device, requires_grad=False) # 使用匹配判别器输出的尺寸
# 获取相对于插值点的梯度
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake, # 梯度输出必须匹配 d_interpolates 的形状
create_graph=True, # 为二阶导数创建计算图 (GP 损失的一部分)
retain_graph=True, # 保留计算图以进行后续计算 (判别器损失)
only_inputs=True,
)[0]
# 重塑梯度以便于计算每个样本的范数
gradients = gradients.view(gradients.size(0), -1)
# 计算 L2 范数和惩罚
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
# --- 训练循环内部 ---
# 假设已定义 'critic', 'real_imgs', 'fake_imgs'
# LAMBDA_GP = 10 # 梯度惩罚系数
# gradient_penalty = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data, device)
# critic_loss = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA_GP * gradient_penalty
# critic_loss.backward()
# optimizer_D.step()
注意: 确保用于
alpha、fake和梯度计算的形状与您的具体数据和判别器输出维度匹配。interpolates上的requires_grad_(True)以及autograd.grad中的create_graph=True、retain_graph=True对于正确计算惩罚是必不可少的。
生成器 (G) 旨在生成判别器给予高分的样本(即让判别器认为它们是真实的)。其损失函数更简单:
LG=−Ex~∼Pg[D(x~)]在实践中,这意味着生成一批伪造样本,将它们输入判别器,并最小化所得分数的负平均值。
# --- 训练循环内部,生成器更新阶段 ---
# 生成伪造图像
# z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
# gen_imgs = generator(z)
# 计算生成器损失
# fake_scores = critic(gen_imgs)
# generator_loss = -torch.mean(fake_scores)
# generator_loss.backward()
# optimizer_G.step()
典型的 WGAN-GP 训练循环涉及判别器和生成器之间的交替更新。常见做法是每次生成器更新执行多次判别器更新。
WGAN-GP 的训练循环结构,侧重于每次生成器更新进行多次判别器更新。
注意事项:
通过实现这些组成部分,特别是梯度惩罚的计算和调整后的损失函数,您可以运用 WGAN-GP 训练更稳定的 GAN,能够生成更高质量的合成数据,与标准 GAN 公式或原始带有权重裁剪的 WGAN 相比。请记住在训练期间监控判别器损失、生成器损失和梯度惩罚的量级,以诊断潜在问题。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造