As we've discussed, the standard Evidence Lower Bound (ELBO) is central to training VAEs. However, the ELBO is, by definition, a lower bound on the log marginal likelihood logpθ(x). The gap between the ELBO and the true log-likelihood is precisely the KL divergence KL(qϕ(z∣x)∣∣pθ(z∣x)). A simpler approximate posterior qϕ(z∣x) often leads to a larger gap, meaning the bound is looser. While making qϕ(z∣x) more expressive (e.g., using normalizing flows, which we'll cover) is one way to tighten this bound, Importance Weighted Autoencoders (IWAEs) offer an alternative: achieve a tighter bound by using multiple samples from the same qϕ(z∣x).
The IWAE Objective: A Tighter Bound with Multiple Samples
The core idea behind IWAEs, introduced by Burda, Grosse, and Salakhutdinov, is to leverage multiple samples from the approximate posterior qϕ(z∣x) to form a better Monte Carlo estimate of pθ(x). Recall that the marginal likelihood can be written as pθ(x) = Ez∼qϕ(z∣x)[pθ(x,z)/qϕ(z∣x)].
The standard ELBO is Ez∼qϕ(z∣x)[log(pθ(x,z)/qϕ(z∣x))]. Jensen's inequality tells us that logE[Y] ≥ E[logY], which is why the ELBO is a lower bound.
The IWAE objective, denoted LK(x), uses K samples z1,…,zK drawn independently from qϕ(z∣x) for a given input x:
LK(x) = Ez1,…,zK∼qϕ(z∣x)[ log( (1/K) ∑k=1…K pθ(x,zk)/qϕ(zk∣x) ) ]
Each term wk = pθ(x,zk)/qϕ(zk∣x) is an importance weight. The IWAE objective averages these weights before taking the logarithm. This is a subtle but significant difference from the standard ELBO.
The LK objective has several desirable properties:
- Monotonically Improving Bound: For any K≥1, we have:
logpθ(x)≥LK(x)≥LK−1(x)
This means that as you increase the number of samples K, the IWAE objective provides a progressively tighter lower bound on the true log marginal likelihood.
- Connection to ELBO: For K=1, the IWAE objective L1(x) is exactly the standard ELBO:
L1(x) = Ez1∼qϕ(z∣x)[log(pθ(x,z1)/qϕ(z1∣x))]
- Convergence to True Log-Likelihood: As K→∞, the IWAE objective LK(x) converges to the true log marginal likelihood logpθ(x). This is because the average (1/K)∑k=1…K wk is an unbiased estimate of pθ(x) whose variance shrinks as K grows, so its logarithm concentrates around logpθ(x).
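The monotone tightening is easy to check numerically. The sketch below uses a toy one-dimensional model (all modeling choices and numbers here are illustrative assumptions, not from the text): prior p(z) = N(0,1), likelihood p(x∣z) = N(z,1), and a deliberately mismatched proposal q(z∣x) = N(0,1), so the exact marginal logpθ(x) = logN(x;0,2) is available for comparison.

```python
import math
import random

random.seed(0)

def log_normal(v, mean, var):
    """Log-density of N(mean, var) evaluated at v."""
    return -0.5 * math.log(2 * math.pi * var) - (v - mean) ** 2 / (2 * var)

def iwae_bound(x, K, outer=10000):
    """Monte Carlo estimate of L_K for the toy model above.

    Because the proposal equals the prior here, the log-weight
    log w_k = log p(x, z_k) - log q(z_k | x) reduces to log p(x | z_k).
    """
    total = 0.0
    for _ in range(outer):
        log_w = [log_normal(x, random.gauss(0.0, 1.0), 1.0) for _ in range(K)]
        m = max(log_w)  # log-sum-exp for numerical stability
        total += m + math.log(sum(math.exp(a - m) for a in log_w)) - math.log(K)
    return total / outer

x = 1.0
true_logp = log_normal(x, 0.0, 2.0)  # exact log p(x): marginally, x ~ N(0, 2)
bounds = {K: iwae_bound(x, K) for K in (1, 10, 100)}
for K, b in bounds.items():
    print(f"L_{K} estimate: {b:.4f}")
print(f"log p(x) = {true_logp:.4f}")
```

Running this shows the estimates increasing with K while staying below the exact log marginal likelihood; the K=1 value matches the analytic ELBO for this toy model.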
The diagram below illustrates the general process of calculating the IWAE objective using K samples.
Multiple samples z1,…,zK are drawn from the inference network qϕ(z∣x). Each sample contributes an importance weight wk=pθ(x∣zk)p(zk)/qϕ(zk∣x). These weights are averaged, and the logarithm of this average forms the IWAE objective LK.
Intuition: Why Averaging Weights First Matters
The standard ELBO (L1) can be thought of as averaging log(wk). The IWAE objective (LK) averages wk first, then takes the logarithm: log(avg(wk)). Due to Jensen's inequality (logE[Y]≥E[logY]), this change directly leads to a tighter bound.
Intuitively, qϕ(z∣x) might be a poor approximation to the true posterior pθ(z∣x). It might assign low probability density to regions where pθ(z∣x) is high. If we only take one sample z1 (as in the standard ELBO computation), and it lands in such a poorly estimated region, we get a bad estimate for that data point.
With IWAE, if even one of the K samples zk happens to fall into a region of high true posterior probability (and thus gets a large importance weight wk), it can significantly lift the average (1/K)∑k wk. This makes the overall estimate less susceptible to any single "unlucky" sample from qϕ(z∣x). The inference network qϕ(z∣x) is still "simple", but the multi-sample estimate of the bound is more robust.
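This effect is visible even with raw numbers. The sketch below uses a hypothetical, deliberately skewed set of importance weights (the values are made up for illustration): averaging the log-weights, as the ELBO effectively does, is dominated by the tiny weights, while the IWAE-style log-of-average is rescued by the single large weight.

```python
import math

# Hypothetical skewed weights: three samples land where q badly mismatches
# the true posterior (tiny weight), one "lucky" sample gets a large weight.
weights = [1e-6, 1e-6, 1e-6, 10.0]

avg_of_logs = sum(math.log(w) for w in weights) / len(weights)  # ELBO-style
log_of_avg = math.log(sum(weights) / len(weights))              # IWAE-style

print(f"average of logs: {avg_of_logs:.3f}")  # dragged far down by tiny weights
print(f"log of average:  {log_of_avg:.3f}")   # lifted by the one large weight
```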
Benefits of Using IWAE
Training a VAE by maximizing LK (for K>1) instead of the standard ELBO (L1) offers several advantages:
- Tighter Log-Likelihood Bounds: This is the most direct benefit. IWAEs typically report better (i.e., higher) test log-likelihood estimates than VAEs trained with the standard ELBO, when both are evaluated with the same number of importance samples.
- Improved Model Learning: Optimizing a tighter bound can lead to better parameter updates for both the generative model pθ(x∣z) (decoder) and the inference network qϕ(z∣x) (encoder). It has been observed that models trained with IWAE can learn richer, more informative latent representations and generate higher quality samples. The encoder qϕ(z∣x) might learn to propose samples zk that are more diverse or better cover the modes of the true posterior, as these will contribute more effectively to a high LK value.
- No Change to Model Architecture: IWAE achieves these benefits without requiring a more complex inference network architecture (like structured VI or normalizing flows would). The expressiveness of qϕ(z∣x) remains the same; it's the objective function that changes.
Computational Cost and Considerations
The primary drawback of IWAE is its increased computational cost during training and evaluation.
For each data point and each gradient step:
- You need to sample K latent variables zk from qϕ(z∣x).
- You need to perform K forward passes through the decoder to compute pθ(x∣zk) for each zk.
- You need to compute qϕ(zk∣x) for each zk.
This means the computational cost scales linearly with K. If K=50 and the decoder is the bottleneck, a training step requires roughly 50 times more decoder computation than a standard VAE, although in practice the K samples can often be processed in parallel as a batch, so the wall-clock slowdown may be smaller.
Choice of K:
The number of samples K is a hyperparameter.
- Small K (e.g., 5 to 10): Offers a modest improvement in the bound with a manageable increase in computation.
- Large K (e.g., 50, 100, or even 5000 for evaluation): Provides a much tighter bound, approaching the true log-likelihood, but at a significant computational cost.
In practice, training is often done with a moderate K (e.g., K=5 or K=50), and evaluation might use a larger K for a more accurate log-likelihood estimate.
Gradient Variance:
While LK is a tighter bound, the variance of its gradient estimator can sometimes be an issue, particularly for very large K or when importance weights are highly skewed. However, the benefits of a tighter bound often outweigh this concern, and techniques like using the reparameterization trick for zk are essential for keeping gradient variance manageable.
Training with the IWAE Objective
Training a VAE with the IWAE objective is similar to training a standard VAE, with the main difference being the loss calculation:
- Forward Pass:
- For an input x, obtain the parameters of qϕ(z∣x) (e.g., μ(x),σ(x) if q is Gaussian).
- Draw K samples zk∼qϕ(z∣x) using the reparameterization trick. For each zk:
- Compute logqϕ(zk∣x).
- Pass zk through the decoder to get pθ(x∣zk).
- Compute logp(zk) (from the prior).
- Calculate the unnormalized log-weight: logw~k=logpθ(x∣zk)+logp(zk)−logqϕ(zk∣x).
- Objective Calculation:
- Compute the IWAE objective LK. To avoid numerical issues with very small weights, it's common to use the log-sum-exp trick:
LK(x)=LogSumExp(logw~1,…,logw~K)−logK
where LogSumExp(a1,…,aK) = log ∑j=1…K exp(aj), computed stably by subtracting maxj aj before exponentiating and adding it back afterwards.
- Backward Pass: Compute gradients of −LK(x) with respect to θ and ϕ, and update the parameters.
The rest of the VAE machinery, including network architectures for the encoder and decoder, remains largely the same.
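The steps above can be condensed into a single objective function. The sketch below is a minimal, framework-free illustration for a one-dimensional model with assumed Gaussian choices (prior N(0,1), decoder likelihood N(z,1), encoder N(μ, σ²)); in a real VAE the encoder and decoder would be neural networks and an autodiff library would supply the backward pass.

```python
import math
import random

random.seed(1)

def log_normal(v, mean, var):
    """Log-density of N(mean, var) evaluated at v."""
    return -0.5 * math.log(2 * math.pi * var) - (v - mean) ** 2 / (2 * var)

def iwae_objective(x, mu, sigma, K):
    """One-draw estimate of L_K for an illustrative 1-D model:
    prior p(z) = N(0,1), decoder p(x|z) = N(z,1), encoder q(z|x) = N(mu, sigma^2)."""
    log_w = []
    for _ in range(K):
        eps = random.gauss(0.0, 1.0)
        z = mu + sigma * eps                       # reparameterization trick
        log_q = log_normal(z, mu, sigma ** 2)      # log q(z_k | x)
        log_prior = log_normal(z, 0.0, 1.0)        # log p(z_k)
        log_lik = log_normal(x, z, 1.0)            # log p(x | z_k) from the "decoder"
        log_w.append(log_lik + log_prior - log_q)  # unnormalized log-weight
    # log-sum-exp trick for numerical stability, then subtract log K
    m = max(log_w)
    return m + math.log(sum(math.exp(a - m) for a in log_w)) - math.log(K)

print(f"L_50 one-draw estimate: {iwae_objective(1.0, 0.5, 0.8, 50):.4f}")
```

Averaged over many draws, this quantity sits just below the exact log marginal likelihood of the toy model; maximizing it with respect to the encoder and decoder parameters is the IWAE training objective.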
Summary and When to Use IWAEs
Importance Weighted Autoencoders provide a principled way to obtain tighter lower bounds on the log-likelihood for VAEs by using multiple importance samples. This often translates into improved model performance, such as better quality generated samples and more meaningful latent representations, without explicitly increasing the complexity of the inference network qϕ(z∣x).
Consider using IWAEs when:
- You need more accurate estimates of the data log-likelihood.
- You suspect that the looseness of the standard ELBO is a limiting factor for your model's performance.
- You have the computational budget to accommodate the K-fold increase in computation per example.
- You want to improve VAE performance without resorting to more complex inference network architectures immediately.
IWAEs represent a significant step up from the basic VAE objective. They demonstrate how rethinking the estimation of the objective itself, rather than just the model components, can lead to substantial gains in generative modeling. Next, we will explore techniques that directly enhance the expressiveness of the approximate posterior qϕ(z∣x).