Wasserstein GAN (WGAN) 和梯度惩罚 (GP) 技术旨在解决原始 WGAN 权重裁剪的局限性,从而稳定训练。WGAN-GP 的实际实现如下。这种方法被普遍认为是 GAN 训练稳定性和样本质量上的显著改进。本实践指南假定您对在 PyTorch 或 TensorFlow 中实现基础 GAN 已感到熟悉。我们将侧重于 WGAN-GP 所需的具体修改。WGAN-GP 的核心组成部分实现 WGAN-GP 主要涉及对判别器网络、损失函数和训练循环的调整。判别器架构: WGAN-GP 中的判别器不包含最终的 Sigmoid 激活函数。其作用是输出一个标量分数(根据 Wasserstein 距离近似值代表的“真实度”),而非概率。输出层应为线性层。损失函数: 我们将标准 GAN 的对数损失替换为源自 Wasserstein 距离估计和梯度惩罚的损失。梯度惩罚: 这是其显著特点。不同于权重裁剪,我们向判别器的损失中添加了一个惩罚项,促使判别器对其输入的梯度范数接近 1。这更有效地强制执行 Lipschitz 约束。训练循环: 通常,判别器在每个训练迭代中的更新频率高于生成器(例如,每次生成器更新对应 5 次判别器更新)。实现判别器损失判别器旨在最大化其对真实样本和生成样本的得分差异,同时包含梯度惩罚。判别器 ($D$) 的损失函数为:$$ L_D = \mathbb{E}{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}{x \sim P_r}[D(x)] + \lambda \mathbb{E}{\hat{x} \sim P{\hat{x}}}[(|\nabla_{\hat{x}} D(\hat{x})|_2 - 1)^2] $$其中:$P_g$ 是生成器分布(伪造样本 $\tilde{x}$)。$P_r$ 是真实数据分布(真实样本 $x$)。$P_{\hat{x}}$ 是插值样本 $\hat{x}$ 的分布。$\lambda$ 是梯度惩罚系数(通常设为 10)。让我们分解实现过程,特别是梯度惩罚项。实现梯度惩罚计算梯度惩罚包含以下几个步骤:采样插值点: 对于批次中的每个真实样本 $x$ 和生成样本 $\tilde{x}$,创建一个插值样本 $\hat{x}$。 $$ \hat{x} = \epsilon x + (1 - \epsilon) \tilde{x} $$ 这里,$\epsilon$ 是从 $U[0, 1]$ 均匀采样的随机数。这需要对批次进行逐元素操作。计算插值点的判别器输出: 将这些插值样本 $\hat{x}$ 输入判别器网络,以获取其得分 $D(\hat{x})$。计算梯度: 计算判别器输出 $D(\hat{x})$ 相对于插值输入 $\hat{x}$ 的梯度。这需要使用您的深度学习框架的自动微分功能(例如 PyTorch 中的 torch.autograd.grad 或 TensorFlow 中的 tf.GradientTape)。很重要的一点是,要确保为输入计算梯度(PyTorch 中需要 create_graph=True,因为梯度惩罚本身就是损失图的一部分)。计算梯度范数: 针对每个插值样本,计算这些梯度的 L2 范数(欧几里得范数)。计算惩罚: 将每个样本的惩罚计算为 $( | \nabla_{\hat{x}} D(\hat{x}) |_2 - 1 )^2$。平均并缩放: 对批次中的惩罚进行平均,并乘以系数 $\lambda$。以下是梯度惩罚函数的一个 PyTorch 实现片段:import torch import torch.autograd as autograd def compute_gradient_penalty(critic, real_samples, fake_samples, device): """计算 WGAN GP 的梯度惩罚损失""" batch_size = real_samples.size(0) # 用于真实样本和伪造样本之间插值的随机权重项 alpha = torch.rand(batch_size, 1, 1, 1, device=device) # 假设为 4D 张量 (B, C, H, W) # 扩展 alpha 以匹配图像维度 alpha = alpha.expand_as(real_samples) # 获取真实样本和伪造样本之间的随机插值 interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) # 获取插值点的判别器得分 d_interpolates = critic(interpolates) # 使用全一张量作为梯度计算的目标 fake = torch.ones(batch_size, 1, device=device, requires_grad=False) # 使用匹配判别器输出的尺寸 # 获取相对于插值点的梯度 gradients = autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=fake, # 梯度输出必须匹配 d_interpolates 的形状 create_graph=True, # 为二阶导数创建计算图 (GP 损失的一部分) retain_graph=True, # 保留计算图以进行后续计算 (判别器损失) only_inputs=True, )[0] # 重塑梯度以便于计算每个样本的范数 gradients = gradients.view(gradients.size(0), -1) # 计算 L2 范数和惩罚 gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() return gradient_penalty # --- 训练循环内部 --- # 假设已定义 'critic', 'real_imgs', 'fake_imgs' # LAMBDA_GP = 10 # 梯度惩罚系数 # gradient_penalty = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data, device) # critic_loss = torch.mean(critic_fake) - torch.mean(critic_real) + LAMBDA_GP * gradient_penalty # critic_loss.backward() # optimizer_D.step()注意: 确保用于 alpha、fake 和梯度计算的形状与您的具体数据和判别器输出维度匹配。interpolates 上的 requires_grad_(True) 以及 autograd.grad 中的 create_graph=True、retain_graph=True 对于正确计算惩罚是必不可少的。实现生成器损失生成器 ($G$) 旨在生成判别器给予高分的样本(即让判别器认为它们是真实的)。其损失函数更简单:$$ L_G = - \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] $$在实践中,这意味着生成一批伪造样本,将它们输入判别器,并最小化所得分数的负平均值。# --- 训练循环内部,生成器更新阶段 --- # 生成伪造图像 # z = torch.randn(batch_size, latent_dim, 1, 1, device=device) # gen_imgs = generator(z) # 计算生成器损失 # fake_scores = critic(gen_imgs) # generator_loss = -torch.mean(fake_scores) # generator_loss.backward() # optimizer_G.step()训练过程典型的 WGAN-GP 训练循环涉及判别器和生成器之间的交替更新。常见做法是每次生成器更新执行多次判别器更新。digraph WGAN_GP_Training { rankdir=TB; node [shape=box, style=rounded, fontname="sans-serif", margin=0.2, color="#adb5bd", fontcolor="#495057"]; edge [fontname="sans-serif", fontsize=10, color="#495057"]; Start [label="开始迭代", shape=ellipse, style=filled, fillcolor="#a5d8ff"]; Loop_Crit [label="判别器迭代 n_critic 步:"]; Update_Crit [label="更新判别器:\n1. 采样真实批次 (x)\n2. 采样噪声 (z),生成伪造批次 (~x)\n3. 计算 D(x), D(~x)\n4. 计算梯度惩罚 (GP)\n5. 计算判别器损失 = D(~x) - D(x) + \u03bb * GP\n6. 反向传播并优化判别器", style=filled, fillcolor="#ffec99"]; Update_Gen [label="更新生成器:\n1. 采样噪声 (z),生成伪造批次 (~x)\n2. 计算 D(~x)\n3. 计算生成器损失 = -D(~x)\n4. 反向传播并优化生成器", style=filled, fillcolor="#b2f2bb"]; End_Crit_Loop [label="结束判别器循环"]; Next_Iter [label="下一次迭代 / 结束训练", shape=ellipse, style=filled, fillcolor="#a5d8ff"]; Start -> Loop_Crit; Loop_Crit -> Update_Crit [label=" 判别器更新"]; Update_Crit -> Loop_Crit [label=" 重复 n_critic 次"]; Loop_Crit -> End_Crit_Loop [label=" n_critic 步后"]; End_Crit_Loop -> Update_Gen [label=" 生成器更新 (1 步)"]; Update_Gen -> Next_Iter; }WGAN-GP 的训练循环结构,侧重于每次生成器更新进行多次判别器更新。注意事项:优化器: 常用 Adam,通常使用特定的超参数,如 $\beta_1=0.0$ 或 $\beta_1=0.5$ 以及 $\beta_2=0.9$。标准的 Adam 设置($\beta_1=0.9, \beta_2=0.999$)也可以,但可能需要更多调整。为生成器和判别器使用独立的优化器实例。学习率: 两者相似的学习率(例如 1e-4 或 2e-4)通常是一个不错的起始点,不同于 TTUR 明确使用不同的学习率。判别器更新次数 ($n_{critic}$): 像 5 这样的值很常见,但这可以调整。它确保判别器为生成器提供可靠的梯度。梯度惩罚系数 ($\lambda$): 通常设为 10,如果训练不稳定或梯度消失/爆炸,可以调整。批归一化: 在 WGAN-GP 的判别器中通常避免使用批归一化,因为它可能在批次中的样本之间引入依赖关系,干扰梯度惩罚的计算。如果需要归一化,层归一化或实例归一化可能是替代方案。对于生成器,批归一化通常仍在使用。通过实现这些组成部分,特别是梯度惩罚的计算和调整后的损失函数,您可以运用 WGAN-GP 训练更稳定的 GAN,能够生成更高质量的合成数据,与标准 GAN 公式或原始带有权重裁剪的 WGAN 相比。请记住在训练期间监控判别器损失、生成器损失和梯度惩罚的量级,以诊断潜在问题。