In this chapter, we've explored several methods to enhance the inference capabilities of Variational Autoencoders. A key takeaway is that the quality of the approximate posterior qϕ(z∣x) significantly impacts VAE performance. While amortized inference provides efficiency, techniques like Importance Weighted Autoencoders (IWAEs) offer a path to a tighter Evidence Lower Bound (ELBO) and potentially more accurate posterior approximations. This practical section will guide you through implementing an IWAE and exploring its effects. We'll also touch upon considerations for other advanced inference strategies.
Recall that the standard VAE maximizes the ELBO:
$$\mathcal{L}_{\text{ELBO}}(x) = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - D_{\text{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right)$$

The IWAE, introduced by Burda et al. (2015), provides a tighter lower bound on the log marginal likelihood logpθ(x) by using multiple samples from the approximate posterior qϕ(z∣x) for each data point x. The IWAE objective, often denoted LK, is:
$$\mathcal{L}_K(x) = \mathbb{E}_{z_1,\ldots,z_K \sim q_\phi(z \mid x)}\left[\log\left(\frac{1}{K}\sum_{k=1}^{K}\frac{p_\theta(x, z_k)}{q_\phi(z_k \mid x)}\right)\right]$$

where zk are K independent samples drawn from qϕ(z∣x). This can be rewritten in a more practical form for implementation:

$$\mathcal{L}_K(x) = \mathbb{E}_{z_1,\ldots,z_K \sim q_\phi(z \mid x)}\left[\log\left(\frac{1}{K}\sum_{k=1}^{K}\exp\left(\log p_\theta(x \mid z_k) + \log p(z_k) - \log q_\phi(z_k \mid x)\right)\right)\right]$$

Notice that when K=1, L1 recovers the standard ELBO (up to a subtle difference in how the expectation is handled in practice, but for optimization it is very similar). As K→∞, LK→logpθ(x).
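Before wiring this into a VAE, it can help to watch the bound tighten on a toy problem where logpθ(x) is available in closed form. The sketch below uses a 1-D Gaussian prior and likelihood (so p(x) = N(0, 2)) with a deliberately crude proposal q(z∣x); the model, the proposal, and the names toy_bound and n_runs are illustrative assumptions, not part of the chapter's VAE code.

# Sanity check on a toy model with tractable log p(x):
# p(z) = N(0, 1), p(x|z) = N(x; z, 1), hence p(x) = N(x; 0, 2).
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
x = torch.tensor(1.5)
prior = Normal(0.0, 1.0)
q = Normal(0.3 * x, 1.0)  # deliberately mismatched proposal; true posterior is N(x/2, 1/2)
true_log_px = Normal(0.0, 2.0 ** 0.5).log_prob(x)

def toy_bound(K, n_runs=2000):
    z = q.sample((n_runs, K))                                  # (n_runs, K) samples from q(z|x)
    log_w = Normal(z, 1.0).log_prob(x) + prior.log_prob(z) - q.log_prob(z)
    return (torch.logsumexp(log_w, dim=1) - math.log(K)).mean()

for K in [1, 5, 10, 50]:
    print(f"K={K:3d}  L_K ≈ {toy_bound(K).item():.4f}   log p(x) = {true_log_px.item():.4f}")

Running this shows the Monte Carlo estimate of LK creeping upward toward log p(x) as K grows, which is exactly the behavior you should expect from the IWAE bound in the full model.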
Let's outline the steps to modify a standard VAE implementation (which you might have from the practical in Chapter 2) into an IWAE. We'll assume you have an encoder that outputs parameters (e.g., mean μϕ(x) and log-variance logσϕ2(x)) for qϕ(z∣x), and a decoder that models pθ(x∣z).
For each input x in a mini-batch, instead of drawing one sample z from qϕ(z∣x), you need to draw K samples. If your encoder outputs μϕ(x) and σϕ(x) for a Gaussian qϕ(z∣x), the reparameterization trick is applied K times: zk=μϕ(x)+σϕ(x)⊙ϵk, where ϵk∼N(0,I) for k=1,…,K.
In terms of tensor operations (e.g., in PyTorch or TensorFlow):
You start with mu and logvar of shape (batch_size, latent_dim). Expand them to (batch_size, K, latent_dim), or process them in a way that allows for K samples per input. Then draw a noise tensor eps of shape (batch_size, K, latent_dim) and compute z_samples of shape (batch_size, K, latent_dim).

# Pseudocode/PyTorch-like illustration
# mu, logvar shapes: (batch_size, latent_dim)
# K: number of importance samples
# Expand mu and logvar for K samples
mu_expanded = mu.unsqueeze(1).expand(-1, K, -1) # (batch_size, K, latent_dim)
logvar_expanded = logvar.unsqueeze(1).expand(-1, K, -1) # (batch_size, K, latent_dim)
std_expanded = torch.exp(0.5 * logvar_expanded)
# Sample epsilon
epsilon = torch.randn_like(std_expanded) # (batch_size, K, latent_dim)
# Generate K latent samples per input
z_samples = mu_expanded + std_expanded * epsilon # (batch_size, K, latent_dim)
For each sample zk, we need to compute its unnormalized log importance weight, wk′:

$$\log w'_k = \log p_\theta(x \mid z_k) + \log p(z_k) - \log q_\phi(z_k \mid x)$$

If your input x has shape (batch_size, data_dim), it might need to be expanded to (batch_size, K, data_dim) to match z_samples when calculating the reconstruction loss for each sample. These components will result in a tensor of log weights, say log_w_prime, of shape (batch_size, K).
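As a concrete sketch, the two analytic terms can be computed directly from the tensors produced in the sampling step, assuming the standard normal prior and the diagonal Gaussian qϕ(z∣x) used throughout this chapter (variable names follow the earlier snippet):

# Sketch: analytic log-density terms per importance sample.
# z_samples, mu_expanded, logvar_expanded: (batch_size, K, latent_dim) from the snippet above.
import math
import torch

# log p(z_k) under the standard normal prior, summed over latent dims -> (batch_size, K)
log_p_z = -0.5 * (z_samples ** 2 + math.log(2 * math.pi)).sum(dim=-1)

# log q(z_k | x) under the diagonal Gaussian posterior -> (batch_size, K)
log_q_z_given_x = -0.5 * (
    (z_samples - mu_expanded) ** 2 / logvar_expanded.exp()
    + logvar_expanded
    + math.log(2 * math.pi)
).sum(dim=-1)

# log p(x | z_k) depends on your decoder's likelihood (Bernoulli, Gaussian, ...);
# the dimension note below shows one way to obtain it with shape (batch_size, K).

An equivalent, often more readable route for the posterior term is torch.distributions.Normal(mu_expanded, std_expanded).log_prob(z_samples).sum(-1), and similarly for the prior.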
The IWAE objective involves log((1/K) ∑k exp(logwk′)). Directly summing exponentials can lead to numerical underflow or overflow. The log-sum-exp (LSE) trick is essential here:

$$\text{LSE}(a_1, \ldots, a_K) = \alpha + \log \sum_{k=1}^{K} \exp(a_k - \alpha)$$

where α = maxk ak. The IWAE loss for a single data point x is then:

$$\mathcal{L}_K(x) = \text{LSE}(\log w'_1, \ldots, \log w'_K) - \log K$$

The final loss for the mini-batch is the average of LK(x) over all x in the batch.
# Pseudocode/PyTorch-like illustration for the IWAE loss term
# log_p_x_given_z: (batch_size, K), reconstruction log-likelihood for each sample
# log_p_z: (batch_size, K), prior log-prob for each sample
# log_q_z_given_x: (batch_size, K), approximate posterior log-prob for each sample
log_w_prime = log_p_x_given_z + log_p_z - log_q_z_given_x # (batch_size, K)
# Log-sum-exp for numerical stability
log_sum_exp_w = torch.logsumexp(log_w_prime, dim=1) # (batch_size,)
# IWAE objective for each data point
iwae_elbo_per_sample = log_sum_exp_w - torch.log(torch.tensor(K, dtype=torch.float32))
# Average over batch
batch_iwae_elbo = torch.mean(iwae_elbo_per_sample)
# The loss to minimize is -batch_iwae_elbo
loss = -batch_iwae_elbo
A Note on Dimensions: Carefully manage tensor dimensions. When you pass zk (shape (batch_size, K, latent_dim)) to the decoder, it might process it as (batch_size * K, latent_dim). The output reconstructions will then be (batch_size * K, data_dim). You'll need to reshape x to match this for computing logpθ(x∣zk), and then reshape the resulting log-likelihoods back to (batch_size, K).
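A minimal sketch of that round-trip, assuming a decoder that returns Bernoulli logits over flattened inputs (decoder here is a placeholder for your own model; swap in the appropriate likelihood if yours differs):

# Sketch: pushing all K samples through the decoder in one call and reshaping back.
import torch.nn.functional as F

batch_size, K, latent_dim = z_samples.shape
z_flat = z_samples.reshape(batch_size * K, latent_dim)                 # (batch_size*K, latent_dim)
x_logits = decoder(z_flat)                                             # (batch_size*K, data_dim)
x_rep = x.unsqueeze(1).expand(-1, K, -1).reshape(batch_size * K, -1)   # repeat each x K times

# Bernoulli log-likelihood per importance sample (x assumed float in [0, 1]),
# reshaped back to (batch_size, K)
log_p_x_given_z = -F.binary_cross_entropy_with_logits(
    x_logits, x_rep, reduction="none"
).sum(dim=-1).reshape(batch_size, K)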
Once you have your IWAE implemented, it's time to experiment:
Vary K: Train IWAE models with different values of K (e.g., K = 1, 5, 10, 50) and compare their held-out bounds; a sketch of such an evaluation loop follows this list.
Plotting the bound against K typically shows that as K increases, the IWAE bound LK improves (becomes less negative), while the computational cost per epoch also rises.
Active Units: If you are also interested in disentanglement or representation learning (covered in Chapter 5), investigate if using IWAE (with K>1) affects the number of "active units" in the latent space compared to a standard VAE. Sometimes, a tighter bound can prevent premature KL vanishing.
Posterior Collapse: For models prone to posterior collapse (where qϕ(z∣x) becomes very similar to p(z), making the latent variables uninformative), does IWAE with K>1 help mitigate this issue? The more accurate gradients provided by IWAE might offer better optimization paths.
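For these comparisons, it is common to evaluate the bound on held-out data, often with a larger K than was used for training. The sketch below assumes a helper iwae_bound(model, x, K) that wires together the sampling, log-weight, and log-sum-exp steps from this section and returns the per-example bound; the helper, model, and test_loader names are placeholders for your own code.

# Sketch: held-out IWAE bound for several values of K.
import torch

@torch.no_grad()
def evaluate_bound(model, data_loader, K):
    model.eval()
    total, count = 0.0, 0
    for x, _ in data_loader:
        x = x.view(x.size(0), -1)                        # flatten images, if applicable
        total += iwae_bound(model, x, K).sum().item()    # per-example L_K, shape (batch_size,)
        count += x.size(0)
    return total / count

for K in [1, 5, 10, 50]:
    print(f"K={K:3d}  held-out L_K ≈ {evaluate_bound(model, test_loader, K):.3f}")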
While IWAE focuses on improving the bound via multiple samples, other techniques modify the structure of qϕ(z∣x) or the inference process itself:
Structured Variational Inference: replaces the fully factorized (mean-field) Gaussian posterior with a more expressive family, for example a full-covariance Gaussian or a normalizing flow applied to the encoder's output.
Auxiliary Variables (e.g., Semi-Amortized VI, Hierarchical VAEs): semi-amortized methods refine the encoder's amortized estimate with a few per-example gradient steps on the ELBO, while hierarchical VAEs introduce multiple layers of latent variables, making both the prior and the posterior more expressive.
Adversarial Variational Bayes (AVB): uses an auxiliary discriminator network to estimate the otherwise intractable density ratio, which allows qϕ(z∣x) to be an implicit distribution defined by a flexible sampler rather than an explicit density.
Implementing these advanced techniques often involves a deeper dive into probabilistic modeling and careful network architecture design. The common thread is moving beyond the simple mean-field Gaussian assumption for qϕ(z∣x) to capture more complex posterior structures or to obtain better estimates of the model evidence.
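As one small illustration of what "more complex posterior structures" can look like in code, here is a single planar normalizing flow step (Rezende and Mohamed, 2015) that could be stacked after the reparameterized Gaussian sample. This is a sketch of the idea rather than a drop-in component of the IWAE code above; the class and its parameterization are illustrative choices.

# Sketch: one planar flow step z' = z + u * tanh(w.z + b), with its log-det-Jacobian.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PlanarFlow(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.u = nn.Parameter(torch.randn(latent_dim) * 0.01)
        self.w = nn.Parameter(torch.randn(latent_dim) * 0.01)
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, z):                                   # z: (batch, latent_dim)
        # Re-parameterize u so that w.u >= -1, keeping the map invertible
        wu = (self.w * self.u).sum()
        u_hat = self.u + (F.softplus(wu) - 1 - wu) * self.w / (self.w.pow(2).sum() + 1e-8)
        lin = z @ self.w + self.b                           # (batch,)
        z_new = z + u_hat * torch.tanh(lin).unsqueeze(-1)   # transformed sample
        psi = (1 - torch.tanh(lin) ** 2).unsqueeze(-1) * self.w
        log_det = torch.log(torch.abs(1 + psi @ u_hat) + 1e-8)
        return z_new, log_det

When combined with the importance weights above, logqϕ(zk∣x) for the transformed sample is the Gaussian log-density of the pre-flow sample minus the accumulated log_det terms.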
This practical exercise with IWAEs should provide a solid foundation for understanding how to improve VAE inference. By tightening the ELBO, IWAEs can lead to better generative models and more faithful latent representations. The increased computational cost is a trade-off, but often, a moderate K (e.g., 5-10) can offer a good balance.
As you progress, consider how the principles of IWAE (multiple samples for better estimates) or the structural enhancements from other advanced methods could be combined or adapted for your specific VAE applications. The ability to critically assess and improve the inference mechanism is a hallmark of advanced VAE development.