In the previous sections, we discussed the theoretical underpinnings of Wasserstein GANs (WGAN) and the Gradient Penalty (GP) technique designed to address the weight clipping limitations of the original WGAN, further stabilizing training. Now, let's translate that theory into practice by implementing WGAN-GP. This approach is widely regarded as a significant improvement for GAN training stability and sample quality.
This practical guide assumes you are comfortable with implementing basic GANs in PyTorch or TensorFlow. We will focus on the specific modifications required for WGAN-GP.
Implementing WGAN-GP involves adjustments primarily to the critic (discriminator) network, the loss functions, and the training loop.
The critic aims to maximize the difference between its scores for real samples and generated samples, while also incorporating the gradient penalty. The loss function for the critic (D) is:
$$L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$

Where:

- $P_r$ is the real data distribution and $P_g$ is the generator's distribution.
- $P_{\hat{x}}$ is the distribution of samples $\hat{x}$ interpolated between real and generated samples (defined below).
- $\lambda$ is the gradient penalty coefficient, typically set to 10.
Let's break down the implementation, particularly the gradient penalty term.
Calculating the gradient penalty involves several steps:
1. Sample Interpolated Points: For each pair of a real sample $x$ and a generated sample $\tilde{x}$ in a batch, create an interpolated sample $\hat{x}$:

   $$\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$$

   Here, $\epsilon$ is a random number sampled uniformly from $U[0, 1]$, drawn independently for each sample in the batch.
2. Calculate Critic Output for Interpolated Points: Pass these interpolated samples $\hat{x}$ through the critic network to get their scores $D(\hat{x})$.
3. Compute Gradients: Calculate the gradients of the critic's outputs $D(\hat{x})$ with respect to the interpolated inputs $\hat{x}$, using your deep learning framework's automatic differentiation capabilities (e.g., `torch.autograd.grad` in PyTorch or `tf.GradientTape` in TensorFlow). Gradients must be computed with respect to the inputs, and the computation graph must be kept (`create_graph=True` in PyTorch), because the gradient penalty itself forms part of the loss graph.
4. Calculate Gradient Norm: Compute the L2 norm (Euclidean norm) of these gradients for each interpolated sample.
5. Compute Penalty: Calculate the penalty for each sample as $\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2$.
6. Average and Scale: Average the penalties across the batch and multiply by the coefficient $\lambda$.
Here's a conceptual PyTorch implementation snippet for the gradient penalty function:
```python
import torch
import torch.autograd as autograd

def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    """Calculates the gradient penalty loss for WGAN-GP."""
    batch_size = real_samples.size(0)
    # Random weight for interpolation between real and fake samples,
    # one value per sample (assuming 4D image tensors: B, C, H, W)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    # Expand alpha to match the image dimensions
    alpha = alpha.expand_as(real_samples)
    # Random interpolation between real and fake samples
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    # Critic scores for the interpolated samples
    d_interpolates = critic(interpolates)
    # Target for grad_outputs; ones_like keeps its shape consistent
    # with the critic output regardless of its exact dimensions
    grad_outputs = torch.ones_like(d_interpolates)
    # Gradients of the critic scores w.r.t. the interpolated inputs
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,   # the penalty itself forms part of the loss graph
        retain_graph=True,   # keep the graph for the rest of the critic loss
        only_inputs=True,
    )[0]
    # Flatten gradients to compute the L2 norm per sample
    gradients = gradients.view(batch_size, -1)
    # Penalty: squared deviation of the per-sample gradient norm from 1
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
```
```python
# --- Inside the training loop: critic update ---
# Assumes critic, real_imgs, fake_imgs, optimizer_D, and device are defined
LAMBDA_GP = 10  # gradient penalty coefficient

optimizer_D.zero_grad()
critic_real = critic(real_imgs)
critic_fake = critic(fake_imgs.detach())  # block gradients into the generator
gradient_penalty = compute_gradient_penalty(critic, real_imgs, fake_imgs.detach(), device)
critic_loss = critic_fake.mean() - critic_real.mean() + LAMBDA_GP * gradient_penalty
critic_loss.backward()
optimizer_D.step()
```
Note: Ensure the shapes used for `alpha` and the gradient calculation match your specific data and critic output dimensions; defining `grad_outputs` with `torch.ones_like(d_interpolates)` keeps that tensor consistent with the critic's output shape automatically. The `requires_grad_(True)` on `interpolates` and the `create_graph=True`, `retain_graph=True` arguments to `autograd.grad` are essential for computing the penalty correctly.
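As a quick way to verify that the shapes line up, you can run the function once with random tensors and a dummy critic. The tiny linear critic and the 32x32 RGB batch sizes below are illustrative placeholders, not part of the WGAN-GP method itself:

```python
import torch.nn as nn

# Illustrative sanity check with a dummy critic and random image batches
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dummy_critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 1)).to(device)
real = torch.randn(8, 3, 32, 32, device=device)
fake = torch.randn(8, 3, 32, 32, device=device)
gp = compute_gradient_penalty(dummy_critic, real, fake, device)
print(gp.item())  # a single non-negative scalar
```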
The generator (G) aims to produce samples that the critic scores highly (i.e., makes the critic think they are real). Its loss is simpler:
$$L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$

In practice, this means generating a batch of fake samples, passing them through the critic, and minimizing the negative mean of the resulting scores.
```python
# --- Inside the training loop: generator update ---
# Assumes generator, critic, optimizer_G, batch_size, latent_dim, and device are defined
optimizer_G.zero_grad()
# Generate fake images from random latent vectors
z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
gen_imgs = generator(z)
# The generator wants the critic to assign high scores to its samples
fake_scores = critic(gen_imgs)
generator_loss = -fake_scores.mean()
generator_loss.backward()
optimizer_G.step()
```
The typical WGAN-GP training loop involves alternating updates between the critic and the generator. A common practice is to perform multiple critic updates for each generator update.
Training loop structure for WGAN-GP, emphasizing multiple critic updates per generator update.
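Putting the pieces together, here is a minimal sketch of such a loop. Names like `dataloader`, `generator`, `critic`, `optimizer_D`, `optimizer_G`, `latent_dim`, and `num_epochs` are assumed to be defined in your setup; `N_CRITIC = 5` follows the value used in the WGAN-GP paper:

```python
N_CRITIC = 5    # critic updates per generator update (value from the WGAN-GP paper)
LAMBDA_GP = 10  # gradient penalty coefficient

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # ---- Critic update (every iteration) ----
        optimizer_D.zero_grad()
        z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_imgs = generator(z).detach()  # do not backprop into the generator here
        critic_real = critic(real_imgs)
        critic_fake = critic(fake_imgs)
        gp = compute_gradient_penalty(critic, real_imgs, fake_imgs, device)
        critic_loss = critic_fake.mean() - critic_real.mean() + LAMBDA_GP * gp
        critic_loss.backward()
        optimizer_D.step()

        # ---- Generator update (once every N_CRITIC iterations) ----
        if i % N_CRITIC == 0:
            optimizer_G.zero_grad()
            z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            gen_imgs = generator(z)
            generator_loss = -critic(gen_imgs).mean()
            generator_loss.backward()
            optimizer_G.step()
```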
Key Considerations:

- Critic normalization: Avoid batch normalization in the critic, because the gradient penalty is computed per sample; the WGAN-GP paper recommends layer normalization or no normalization instead.
- Hyperparameters: The original paper uses $\lambda = 10$, five critic updates per generator update, and Adam with a learning rate of 1e-4 and betas of (0, 0.9).
- Monitoring: The critic's estimate of the Wasserstein distance, $\mathbb{E}[D(x)] - \mathbb{E}[D(\tilde{x})]$, tends to correlate with sample quality and is a useful training signal; note that the critic loss can legitimately be negative.
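For example, the optimizer settings from the paper translate directly to PyTorch; `generator` and `critic` are assumed to be defined `nn.Module` instances:

```python
import torch.optim as optim

# Adam settings recommended in the WGAN-GP paper (Gulrajani et al., 2017)
optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_D = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
```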
By implementing these components, particularly the gradient penalty calculation and the adjusted loss functions, you can use WGAN-GP to train GANs more stably and generate higher-quality synthetic data than the standard GAN formulation or the original WGAN with weight clipping. Remember to monitor the critic loss, generator loss, and the magnitude of the gradient penalty during training to diagnose potential issues.