While the Wasserstein GAN (WGAN) introduced a more stable loss function based on the Earth Mover's distance, its original method for enforcing the necessary 1-Lipschitz constraint on the discriminator, weight clipping, came with its own set of problems. Clipping the weights to a small range (e.g., [−0.01,0.01]) can lead to optimization difficulties: if the clipping range is too small, gradients can vanish, hindering learning; if it's too large, gradients might explode, leading to instability. Furthermore, weight clipping biases the discriminator towards learning overly simple functions, potentially reducing its capacity to capture the complexity of the real data distribution.
To address these limitations, an improved technique called the gradient penalty was proposed; the resulting model is known as WGAN-GP. Instead of crudely forcing weights into a box, WGAN-GP directly penalizes the norm of the discriminator's gradient with respect to its input, encouraging it to stay close to 1. This acts as a softer, more targeted way to enforce the Lipschitz constraint.
The Gradient Penalty Term
The core idea is to add a penalty term to the discriminator's loss function. This penalty is designed to push the L2 norm (Euclidean norm) of the discriminator's gradient towards 1, specifically for points sampled between the real and generated data distributions.
Mathematically, the gradient penalty term is defined as:
$$\lambda\,\mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$
Let's break this down:
$D(\hat{x})$: The output of the discriminator (critic) for an input sample $\hat{x}$.
$\nabla_{\hat{x}} D(\hat{x})$: The gradient of the discriminator's output with respect to its input $\hat{x}$. This tells us how sensitive the discriminator's output is to changes in the input.
$\lVert \cdot \rVert_2$: The L2 norm (Euclidean norm) of the gradient vector. For a vector $v = (v_1, v_2, \ldots, v_n)$, $\lVert v \rVert_2 = \sqrt{v_1^2 + v_2^2 + \cdots + v_n^2}$.
$(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2$: This term calculates the squared difference between the gradient's norm and 1. It becomes zero when the norm is exactly 1 and increases quadratically as the norm deviates from 1. This penalizes gradients whose norms are far from 1.
$\hat{x} \sim P_{\hat{x}}$: This indicates that the expectation is taken over samples $\hat{x}$ drawn from a specific distribution $P_{\hat{x}}$. These samples are generated by interpolating between pairs of real samples ($x \sim P_{\text{data}}$) and generated samples ($\tilde{x} \sim P_g$). Specifically:
$$\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$$
where $\epsilon$ is a random number sampled uniformly from $[0, 1]$.
$\lambda$: A hyperparameter that controls the strength of the gradient penalty relative to the original WGAN loss. A common value is $\lambda = 10$.
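To make the computation concrete, here is a minimal PyTorch sketch of the penalty term. The function name `gradient_penalty` and the critic interface (a batch of samples in, one scalar score per sample out) are illustrative assumptions, not a canonical implementation:

```python
import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    """Sketch of the gradient penalty term defined above.

    Assumes `critic` maps a batch of samples to a batch of scalar
    scores, and that `real` and `fake` have identical shapes.
    """
    batch_size = real.size(0)
    # One epsilon per real/fake pair, broadcast over remaining dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=device)
    # Interpolate: x_hat = eps * x + (1 - eps) * x_tilde
    x_hat = eps * real + (1 - eps) * fake
    x_hat.requires_grad_(True)

    scores = critic(x_hat)
    # Gradient of the critic's output w.r.t. the interpolated input.
    # create_graph=True keeps the graph so the penalty itself can be
    # backpropagated through (a second-order gradient).
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grads = grads.reshape(batch_size, -1)
    # (||grad||_2 - 1)^2, averaged over the batch.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```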
The following diagram illustrates the sampling of interpolated points $\hat{x}$:

[Figure: Interpolation between a real sample $x$ and a generated sample $\tilde{x}$ to create $\hat{x}$. The gradient penalty is evaluated at these interpolated points.]
Why Interpolated Samples?
The Wasserstein distance theory requires the discriminator (critic) to be 1-Lipschitz everywhere. Enforcing this globally is computationally difficult. The WGAN-GP paper demonstrated empirically that enforcing the constraint only along these straight lines between real and generated samples is sufficient to achieve stable training. Intuitively, this focuses the constraint on the regions of the input space that are currently relevant to the generator's learning process.
The penalty function $(\lVert g \rVert_2 - 1)^2$ encourages the gradient norm $\lVert g \rVert_2$ to be close to 1, as shown below:

[Figure: The gradient penalty $(\lVert g \rVert_2 - 1)^2$ plotted against the gradient norm $\lVert g \rVert_2$. The penalty is minimized (zero) when the norm is exactly 1, incentivizing the discriminator to satisfy this condition.]
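For example, a gradient norm of $0.5$ incurs a penalty of $(0.5 - 1)^2 = 0.25$, while a norm of $3$ incurs $(3 - 1)^2 = 4$: deviations are punished increasingly severely the further the norm strays from 1.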
The WGAN-GP Objective Functions
With the gradient penalty added, the discriminator's objective is to maximize $L_D$:

$$L_D = \mathbb{E}_{x \sim P_{\text{data}}}[D(x)] - \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \lambda\,\mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$
Note that in implementation we equivalently minimize the negative of this loss, $-L_D$. The generator's objective remains the same as in the original WGAN: minimize $L_G$, which is equivalent to maximizing the discriminator's score for fake samples:
$$L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$
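As an illustration, here is a minimal sketch of one critic step and one generator step under these objectives. It assumes the hypothetical `gradient_penalty` helper from the earlier sketch, plus pre-built `critic`, `generator`, optimizers `opt_critic` and `opt_gen`, a batch of `real` samples, `noise`, and `LAMBDA_GP = 10`:

```python
# Critic (discriminator) update: minimize -L_D.
fake = generator(noise)
critic_real = critic(real).mean()
critic_fake = critic(fake.detach()).mean()
gp = gradient_penalty(critic, real, fake.detach(), device=real.device)
# -L_D = E[D(x_tilde)] - E[D(x)] + lambda * GP
loss_critic = critic_fake - critic_real + LAMBDA_GP * gp

opt_critic.zero_grad()
loss_critic.backward()
opt_critic.step()

# Generator update: minimize L_G = -E[D(x_tilde)].
fake = generator(noise)
loss_gen = -critic(fake).mean()

opt_gen.zero_grad()
loss_gen.backward()
opt_gen.step()
```

In practice the critic is usually updated several times per generator update, with fresh noise drawn for each step; the sketch shows a single pair of updates for clarity.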
Advantages of Gradient Penalty
Using the gradient penalty offers several significant advantages over weight clipping:
Improved Stability: It generally leads to more stable training convergence compared to weight clipping, avoiding the specific failure modes associated with poorly chosen clipping parameters.
Higher Model Capacity: By not restricting weights directly, the discriminator can learn more complex functions, potentially leading to a better approximation of the Wasserstein distance and higher-quality generated samples.
No Hyperparameter Tuning for Clipping: Eliminates the need to tune the clipping range $c$, replacing it with the penalty coefficient $\lambda$, which is often less sensitive (typically $\lambda = 10$ works well).
Implementation Notes
Computational Cost: Calculating the gradient penalty requires gradients of gradients: the penalty term contains $\nabla_{\hat{x}} D(\hat{x})$, so backpropagating the penalty into the discriminator's parameters differentiates through that gradient. This involves a second backward pass through the part of the computation graph that maps $\hat{x}$ to the discriminator's output, adding computational overhead compared to standard GANs or WGAN with weight clipping. Most deep learning frameworks provide utilities to compute these higher-order gradients efficiently (e.g., `create_graph=True` in the sketch above).
Normalization: The original WGAN-GP paper suggests avoiding Batch Normalization in the discriminator, as it introduces dependencies between samples within a batch, which can interfere with the gradient penalty calculation (the penalty is defined per independent sample). Alternative normalization techniques like Layer Normalization, Instance Normalization, or removing normalization altogether might be preferred; see the sketch after these notes.
Sampling $\hat{x}$: For each batch, you need to sample $\epsilon \sim U[0, 1]$ (usually one $\epsilon$ per real/fake sample pair, or sometimes one $\epsilon$ broadcast across the batch) and compute the interpolated samples $\hat{x}$.
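Tying the normalization note to code, below is a sketch of a DCGAN-style critic for 64x64 RGB images that avoids Batch Normalization. `InstanceNorm2d` is used here as one per-sample alternative (the WGAN-GP paper itself used Layer Normalization), and all layer sizes are illustrative assumptions:

```python
import torch.nn as nn

class Critic(nn.Module):
    """Sketch of a critic for 64x64 RGB images.

    Uses InstanceNorm2d instead of BatchNorm2d so each sample is
    normalized independently, keeping the per-sample gradient
    penalty well-defined.
    """
    def __init__(self, channels=3, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # 64x64 -> 32x32 (no normalization on the first layer)
            nn.Conv2d(channels, features, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 32x32 -> 16x16
            nn.Conv2d(features, features * 2, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 2, affine=True),
            nn.LeakyReLU(0.2),
            # 16x16 -> 8x8
            nn.Conv2d(features * 2, features * 4, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 4, affine=True),
            nn.LeakyReLU(0.2),
            # 8x8 -> 4x4
            nn.Conv2d(features * 4, features * 8, 4, stride=2, padding=1),
            nn.InstanceNorm2d(features * 8, affine=True),
            nn.LeakyReLU(0.2),
            # 4x4 -> one scalar score (no sigmoid: this is a critic)
            nn.Conv2d(features * 8, 1, 4, stride=1, padding=0),
        )

    def forward(self, x):
        # Flatten (B, 1, 1, 1) scores to a (B,) vector.
        return self.net(x).view(x.size(0))
```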
WGAN-GP represents a significant advancement in stabilizing GAN training. By replacing weight clipping with a theoretically motivated gradient penalty, it allows for training deeper and more complex GANs, capable of generating higher-fidelity results while mitigating many of the optimization problems that plagued earlier methods. It has become a standard technique used in many subsequent state-of-the-art GAN architectures.