As we saw, vanilla GAN training often struggles with instability. The generator might produce only a narrow range of outputs (mode collapse), or the learning process might diverge entirely due to vanishing or exploding gradients. A significant contributing factor to these issues lies in the very nature of the objective function used: the Jensen-Shannon (JS) divergence.
While the JS divergence is a valid way to measure the difference between probability distributions, it has properties (most notably, it saturates when the two distributions barely overlap) that make it poorly suited to the adversarial training dynamics of GANs, especially in high-dimensional spaces like images.
The Problem with Jensen-Shannon Divergence
The original GAN paper showed that minimizing the GAN objective corresponds to minimizing the JS divergence between the real data distribution Pr and the generated data distribution Pg. However, consider what happens when Pr and Pg have supports that do not overlap significantly, or when they lie on low-dimensional manifolds within a much higher-dimensional space (a common scenario for image data).
In such cases, it's often possible to find a discriminator that perfectly separates real samples from generated samples. When the discriminator achieves near-perfect accuracy:
- The JS divergence between Pr and Pg effectively saturates at its maximum value (log 2).
- More importantly, the gradients of the loss function with respect to the generator's parameters approach zero.
Think about it: if the discriminator can always tell real from fake, it provides very little useful information back to the generator about how to improve. The generator essentially hits a wall, receiving no meaningful signal to guide its updates. This vanishing gradient problem is a primary cause of training failure and mode collapse, as the generator stops learning effectively.
We need a way to measure the distance between Pr and Pg that provides smoother, more informative gradients even when the distributions are quite different or don't overlap perfectly.
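To make the saturation concrete, here is a small numerical sketch (a NumPy toy with discrete distributions, illustrative only and not part of the WGAN setup itself). It shows that the JS divergence hits its ceiling of log 2 for disjoint distributions, regardless of whether they are "close" or "far":

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions (natural log)."""
    m = 0.5 * (p + q)
    def kl(a, b):
        # Sum only where a > 0; by convention 0 * log(0 / b) = 0.
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# Two comparisons with disjoint supports: all mass sits on different bins.
p      = np.array([1.0, 0.0, 0.0, 0.0])
q_near = np.array([0.0, 1.0, 0.0, 0.0])  # adjacent to p
q_far  = np.array([0.0, 0.0, 0.0, 1.0])  # far from p

# JS divergence is log 2 in both cases: it cannot tell "near" from "far",
# so it offers no signal for how to move q toward p.
print(js_divergence(p, q_near), js_divergence(p, q_far), np.log(2))
```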
Introducing the Wasserstein Distance (Earth Mover's Distance)
Enter the Wasserstein-1 distance, often referred to as the Earth Mover's Distance (EMD). Imagine Pr and Pg as two different piles of dirt (probability mass). The EMD represents the minimum "cost" required to transform one pile into the other. The cost is defined as the amount of dirt moved multiplied by the distance it is moved.
The Wasserstein distance calculates the minimum cost to transport mass from the generated distribution Pg to match the real distribution Pr.
Crucially, even if the two piles of dirt (distributions) are located far apart (disjoint supports), the cost of moving the dirt provides a meaningful measure of how different they are. For example, for two point masses at 0 and θ, the Wasserstein distance is exactly |θ|, whereas the JS divergence is log 2 for every θ ≠ 0. Unlike JS divergence, the Wasserstein distance doesn't saturate abruptly; it varies smoothly as the distributions move relative to each other.
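This difference can be checked numerically. The sketch below uses SciPy's `wasserstein_distance` on toy 1-D samples (the particular Gaussians are illustrative, not from any GAN): the JS divergence would be log 2 in both comparisons, but the Wasserstein distance reflects how far the generated mass would have to travel.

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(0)
# Narrow 1-D distributions whose supports essentially do not overlap.
real      = rng.normal(loc=0.0, scale=0.1, size=10_000)
fake_near = rng.normal(loc=1.0, scale=0.1, size=10_000)
fake_far  = rng.normal(loc=5.0, scale=0.1, size=10_000)

# The Wasserstein distance tracks *how far* the mass must move: ~1 vs. ~5,
# whereas JS divergence would saturate at log 2 in both cases.
print(wasserstein_distance(real, fake_near))  # ~1.0
print(wasserstein_distance(real, fake_far))   # ~5.0
```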
The Kantorovich-Rubinstein Duality
Calculating the EMD directly using its primal definition involves finding an optimal "transport plan" (how much dirt to move from each point in Pg to each point in Pr), which is computationally intractable in most interesting cases. Fortunately, the Kantorovich-Rubinstein duality provides an alternative formulation for the Wasserstein-1 distance:
W1(Pr, Pg) = sup_{∥f∥L≤1} ( Ex∼Pr[f(x)] − Ex~∼Pg[f(x~)] )
Let's break this down:
- W1(Pr,Pg) is the Wasserstein-1 distance between the real distribution Pr and the generated distribution Pg.
- sup denotes the supremum, which is the least upper bound (similar to a maximum).
- ∥f∥L≤1 means the supremum is taken over all functions f that are 1-Lipschitz continuous. A function f is K-Lipschitz if ∣f(x1)−f(x2)∣≤K∥x1−x2∥ for all x1,x2 in its domain. Essentially, a 1-Lipschitz function's rate of change is bounded by 1. This constraint is fundamental.
- Ex∼Pr[f(x)] is the expected value of f(x) when x is sampled from the real data distribution Pr.
- Ex~∼Pg[f(x~)] is the expected value of f(x~) when x~ is sampled from the generator's distribution Pg (where x~=g(z), with z sampled from some prior noise distribution p(z)).
This dual formulation states that the Wasserstein distance is the maximum possible difference in expectations achievable by evaluating a 1-Lipschitz function f on samples from the two distributions.
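One way to build intuition for the dual form is a toy 1-D check. Restricting f to linear functions f(x) = w·x (which are 1-Lipschitz exactly when |w| ≤ 1) and scanning over admissible slopes recovers the Wasserstein distance for two equal-variance Gaussians, where a linear f with slope ±1 happens to be optimal. This is an illustrative sketch, not a practical estimator:

```python
import numpy as np
from scipy.stats import wasserstein_distance

rng = np.random.default_rng(1)
real = rng.normal(loc=2.0, scale=0.5, size=50_000)   # samples from Pr
fake = rng.normal(loc=-1.0, scale=0.5, size=50_000)  # samples from Pg

# Restrict f to f(x) = w*x, 1-Lipschitz iff |w| <= 1, and take the sup of
# E_Pr[f(x)] - E_Pg[f(x)] over a grid of admissible slopes.
slopes = np.linspace(-1.0, 1.0, 201)
dual_estimate = max(w * (real.mean() - fake.mean()) for w in slopes)

# For equal-variance 1-D Gaussians the optimal f is linear with slope +/-1,
# so the dual estimate matches the primal (transport-based) distance of ~3.0.
print(dual_estimate, wasserstein_distance(real, fake))
```

The sup is attained at w = 1 here because Pr sits to the right of Pg; for richer distributions the optimal f is generally nonlinear, which is exactly why WGANs use a neural network to search for it.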
Wasserstein GANs (WGANs)
The insight of the WGAN paper (Arjovsky et al., 2017) was to use this dual formulation to define a new GAN objective. The core idea is:
- Replace the Discriminator with a Critic: Instead of a discriminator outputting the probability of a sample being real (a value between 0 and 1, typically via a sigmoid), we use a network, called the critic (let's denote it fw with parameters w), that outputs an unbounded real number.
- Train the Critic to Approximate the Supremum: The critic fw is trained to make the term (Ex∼Pr[fw(x)]−Ex~∼Pg[fw(x~)]) as large as possible. By doing so, under the right conditions (specifically, if fw remains within the space of 1-Lipschitz functions), the critic's objective value approximates the Wasserstein distance W1(Pr,Pg).
- Train the Generator to Minimize the Critic's Output: The generator gθ (with parameters θ) is trained to produce samples x~=gθ(z) that make the critic's output fw(x~) larger (closer to the output for real samples), thereby minimizing the difference the critic is trying to maximize. This effectively minimizes the estimated Wasserstein distance.
The objectives become:
- Critic Loss: Maximize LCritic=Ex∼Pr[fw(x)]−Ez∼p(z)[fw(gθ(z))]. In practice, neural networks are trained via gradient descent (minimization), so we typically minimize −LCritic.
- Generator Loss: Minimize LGenerator=−Ez∼p(z)[fw(gθ(z))]. The generator tries to fool the critic by producing samples that score highly (as if they were real).
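These two objectives can be sketched end to end on a toy problem. Below, a 1-D "generator" simply shifts noise by θ, and the "critic" is a linear function fw(x) = w·x kept 1-Lipschitz by clipping w to [-1, 1]; the learning rates, batch size, and n_critic value are illustrative choices for this sketch, not the paper's settings:

```python
import numpy as np

rng = np.random.default_rng(42)
theta, w = 0.0, 0.0         # generator shift; linear critic f_w(x) = w * x
lr_g, lr_c, n_critic = 0.02, 0.5, 5

for step in range(250):
    # Critic updates: gradient *ascent* on E[f_w(real)] - E[f_w(fake)].
    for _ in range(n_critic):
        real = rng.normal(3.0, 1.0, 4096)          # samples from Pr ~ N(3, 1)
        fake = theta + rng.normal(0.0, 1.0, 4096)  # g_theta(z) = theta + z
        grad_w = real.mean() - fake.mean()         # dL_critic/dw for f_w(x) = w*x
        w = float(np.clip(w + lr_c * grad_w, -1.0, 1.0))  # keep f_w 1-Lipschitz
    # Generator update: minimize -E[f_w(g_theta(z))]; for f_w(x) = w*x the
    # gradient w.r.t. theta is simply -w, so descent adds lr_g * w.
    theta += lr_g * w

print(theta)  # approximately 3.0, the mean of the real data
```

Even in this toy, the critic's value real.mean() − fake.mean() stays informative while the distributions are disjoint, which is precisely the property the JS-based loss lacked.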
Why This Helps Stability
Using the Wasserstein distance via the critic provides several advantages:
- Meaningful Gradients: The Wasserstein distance provides gradients that are non-zero and informative even when Pr and Pg have disjoint supports. This drastically reduces the vanishing gradient problem that plagued original GANs.
- Improved Correlation with Sample Quality: Empirically, the WGAN loss (the estimated Wasserstein distance) often correlates much better with the perceptual quality of the generated samples during training. Lower loss generally means better-looking images, which wasn't always true for the original GAN loss.
- Reduced Risk of Mode Collapse: Theoretical arguments and empirical results suggest that WGANs are less prone to mode collapse, likely due to the more stable gradients allowing the generator to explore the data distribution more effectively.
The Lipschitz Constraint: A Practical Hurdle
The theory behind WGAN relies heavily on the critic fw being (or approximating) a 1-Lipschitz function. If the critic's weights grow too large, it can violate this constraint, making the approximation of the Wasserstein distance invalid and potentially leading back to unstable training.
Therefore, a crucial part of implementing WGANs is enforcing this Lipschitz constraint on the critic network. The original WGAN paper proposed a simple, albeit sometimes problematic, method called weight clipping. Subsequent research introduced more sophisticated and often preferred techniques like the gradient penalty (WGAN-GP) and spectral normalization. We will examine these methods for enforcing the Lipschitz constraint in the following sections.
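As a preview, weight clipping itself is almost trivially simple to state: after every critic update, clamp each critic parameter to a fixed interval [-c, c]. A minimal NumPy sketch (the parameter names are hypothetical; c = 0.01 is the default used in the WGAN paper):

```python
import numpy as np

def clip_weights(params, c=0.01):
    """Weight clipping as in the original WGAN paper: after each critic
    update, clamp every critic parameter to the interval [-c, c]."""
    return {name: np.clip(value, -c, c) for name, value in params.items()}

# Hypothetical critic parameters (random values standing in for trained weights).
params = {"W1": np.random.randn(4, 4), "b1": np.random.randn(4)}
clipped = clip_weights(params, c=0.01)
print(max(np.abs(v).max() for v in clipped.values()))  # <= 0.01
```

The crude part is the choice of c: too small and the critic loses capacity, too large and the constraint is violated, which is what motivates the gradient-penalty and spectral-normalization alternatives discussed next.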