As we've discussed, the standard amortized inference network in VAEs, while efficient, often employs a simple distribution (like a diagonal Gaussian) for qϕ(z∣x). This simplicity can be a bottleneck, preventing qϕ(z∣x) from accurately approximating a potentially complex true posterior pθ(z∣x). This section explores two powerful strategies to create more expressive and accurate approximate posteriors: leveraging auxiliary variables and employing semi-amortized inference schemes.
One way to enhance the flexibility of the approximate posterior qϕ(z∣x) without making its direct functional form overly complex is to introduce auxiliary random variables. These variables are not part of the original generative model pθ(x∣z) but are used within the inference network to help shape a richer distribution for z.
Let's denote these auxiliary variables by a. Instead of defining qϕ(z∣x) directly, we define a joint distribution qϕ(z,a∣x) over both the original latent variables z and these new auxiliary variables a. A common factorization for this joint distribution is hierarchical:
qϕ(z,a∣x) = qϕ(z∣x,a) qϕ(a∣x)

Here, qϕ(a∣x) is an inference network that maps an input x to the parameters of a distribution over a. Then, qϕ(z∣x,a) is another inference network that maps x and a sample a∼qϕ(a∣x) to the parameters of a distribution over z.
The resulting marginal distribution for z, qϕ(z∣x)=∫qϕ(z∣x,a)qϕ(a∣x)da, can be significantly more complex and flexible than if we had modeled qϕ(z∣x) directly with a simple family (e.g., a single Gaussian). Think of it as using a to "steer" or "refine" the inference for z. For example, qϕ(a∣x) could capture some high-level aspects of the posterior, and qϕ(z∣x,a) could then model finer details conditioned on these aspects.
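To make the hierarchical factorization concrete, here is a minimal PyTorch-style sketch of an auxiliary-variable inference network. The module name, layer sizes, and the choice of diagonal Gaussians for both qϕ(a∣x) and qϕ(z∣x,a) are illustrative assumptions, not a prescribed architecture.

```python
import torch
import torch.nn as nn

class AuxiliaryEncoder(nn.Module):
    """Hierarchical inference network q(z, a | x) = q(z | x, a) q(a | x).
    Both factors are diagonal Gaussians; marginalizing over a makes the
    implied q(z | x) a continuous mixture of Gaussians."""
    def __init__(self, x_dim, a_dim, z_dim, hidden=256):
        super().__init__()
        # q(a | x): maps x to the mean and log-variance of a
        self.enc_a = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(),
                                   nn.Linear(hidden, 2 * a_dim))
        # q(z | x, a): maps [x, a] to the mean and log-variance of z
        self.enc_z = nn.Sequential(nn.Linear(x_dim + a_dim, hidden), nn.ReLU(),
                                   nn.Linear(hidden, 2 * z_dim))

    def forward(self, x):
        # Sample a ~ q(a | x) with the reparameterization trick
        mu_a, logvar_a = self.enc_a(x).chunk(2, dim=-1)
        a = mu_a + torch.randn_like(mu_a) * (0.5 * logvar_a).exp()
        # Sample z ~ q(z | x, a), conditioned on the sampled a
        mu_z, logvar_z = self.enc_z(torch.cat([x, a], dim=-1)).chunk(2, dim=-1)
        z = mu_z + torch.randn_like(mu_z) * (0.5 * logvar_z).exp()
        return z, (mu_z, logvar_z), a, (mu_a, logvar_a)
```

Even though each conditional here is a simple diagonal Gaussian, the marginal over z obtained by integrating out a is not.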
To incorporate this into the VAE framework, we adjust the Evidence Lower Bound (ELBO). We treat a as additional latent variables and assume a simple prior p(a) for them (e.g., a standard Normal distribution, p(a)=N(0,I)). The ELBO for this augmented system is:
L(x;θ,ϕ) = Eqϕ(z,a∣x)[log pθ(x∣z)] − DKL(qϕ(z,a∣x) ∣∣ p(z)p(a))

Note that the generative model pθ(x∣z) still depends only on z. The auxiliary variables a are only "seen" by the inference machinery and the prior p(a). Using our chosen factorization qϕ(z,a∣x)=qϕ(z∣x,a)qϕ(a∣x) and assuming p(z,a)=p(z)p(a) (i.e., z and a are independent in the prior), the KL divergence term can be decomposed:
DKL(qϕ(z,a∣x) ∣∣ p(z)p(a)) = Eqϕ(a∣x)[DKL(qϕ(z∣x,a) ∣∣ p(z))] + DKL(qϕ(a∣x) ∣∣ p(a))

So, the ELBO becomes:
L = Eqϕ(a∣x)[ Eqϕ(z∣x,a)[log pθ(x∣z)] − DKL(qϕ(z∣x,a) ∣∣ p(z)) ] − DKL(qϕ(a∣x) ∣∣ p(a))

This expression looks like a VAE objective where qϕ(a∣x) acts as an "encoder" for a, and then, conditioned on a, qϕ(z∣x,a) acts as another "encoder" for z. The overall structure allows qϕ(z∣x) to implicitly represent a mixture of simpler distributions, leading to a much richer family for the approximate posterior.
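As a rough sketch of how this objective can be estimated in practice, the snippet below computes a single-sample Monte Carlo estimate of the augmented ELBO using the encoder sketched above. It assumes standard Normal priors p(z) and p(a), so both KL terms have the usual closed form, and a hypothetical decoder.log_prob(x, z) method that returns log pθ(x∣z) per example.

```python
def gaussian_kl_std_normal(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(dim=-1)

def augmented_elbo(x, encoder, decoder):
    """Single-sample Monte Carlo estimate of the auxiliary-variable ELBO:
    E_q(a|x)[ E_q(z|x,a)[log p(x|z)] - KL(q(z|x,a) || p(z)) ] - KL(q(a|x) || p(a))."""
    z, (mu_z, logvar_z), a, (mu_a, logvar_a) = encoder(x)
    log_px_given_z = decoder.log_prob(x, z)          # hypothetical decoder API
    kl_z = gaussian_kl_std_normal(mu_z, logvar_z)    # inner KL, for this sample of a
    kl_a = gaussian_kl_std_normal(mu_a, logvar_a)
    return (log_px_given_z - kl_z - kl_a).mean()     # average over the batch
```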
Benefits: The marginal qϕ(z∣x) becomes a continuous mixture of simple components, so it can capture multimodality and other structure that a single diagonal Gaussian cannot, while each conditional factor remains easy to sample from and evaluate with the reparameterization trick.

Costs: You need extra inference networks (and parameters) for a, the Monte Carlo estimate of the objective requires sampling both a and z, and because the KL term is taken over the joint qϕ(z,a∣x), the resulting bound is in general looser than the ELBO you would obtain if you could work with the exact marginal qϕ(z∣x) directly.
Models like Auxiliary Deep Generative Models (ADGMs) and some variants of hierarchical VAEs (when applied to the inference side) are examples of this approach. This technique is distinct from Normalizing Flows (which transform a simple noise distribution into a complex posterior using invertible functions) but can be complementary.
Amortized variational inference, where a single neural network qϕ(z∣x) directly outputs the parameters of the approximate posterior for any given x, is computationally efficient. However, it makes a strong assumption: that a single set of network parameters ϕ can provide optimal (or near-optimal) variational parameters for all datapoints. This can lead to an "amortization gap", the difference in ELBO quality between what a fully amortized qϕ(z∣x) can achieve and what could be achieved if we optimized the variational parameters for each datapoint individually.
Semi-amortized variational inference aims to bridge this gap. The core idea is to use the amortized inference network to provide a good initialization for the variational parameters for a specific datapoint xi. Then, these initial parameters are refined through a few steps of optimization, specifically for that xi, by directly maximizing the ELBO with respect to the variational parameters for that instance.
Let λ denote the parameters of the approximate posterior for a single datapoint xi (e.g., if q(z∣xi) is Gaussian, λ=(μi,σi)). The process is as follows:

1. Initialize: run the amortized inference network on xi to obtain an initial estimate λ(0).
2. Refine: for t = 1, …, T, update λ(t) by taking a gradient ascent step on the per-instance ELBO L(xi; θ, λ) with respect to λ, holding θ and ϕ fixed.
3. Use: treat the refined λ(T) as the variational parameters for xi when evaluating the ELBO and computing gradients for the generative parameters θ (and, in fully semi-amortized training, backpropagate through the refinement steps to update ϕ).
Diagram: the semi-amortized inference process. An amortized network provides an initial estimate of the posterior parameters, which are then refined through instance-specific optimization.
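The refinement loop itself can be sketched in a few lines. The version below assumes a standard amortized Gaussian encoder that returns (μ, log σ²), reuses the hypothetical decoder.log_prob and gaussian_kl_std_normal helpers from the earlier snippets, and, for simplicity, does not backpropagate through the refinement steps to update ϕ.

```python
def refine_posterior(x, encoder, decoder, T=5, lr=1e-2):
    """Semi-amortized refinement: start from the amortized q(z|x) and take T
    gradient-ascent steps on the per-instance ELBO with respect to the
    variational parameters lambda = (mu, logvar) only."""
    with torch.no_grad():
        mu, logvar = encoder(x)                      # amortized initialization lambda^(0)
    lam = [mu.clone().requires_grad_(True), logvar.clone().requires_grad_(True)]
    opt = torch.optim.SGD(lam, lr=lr)                # only the tensors in lam are updated
    for _ in range(T):
        mu_i, logvar_i = lam
        z = mu_i + torch.randn_like(mu_i) * (0.5 * logvar_i).exp()   # reparameterized sample
        elbo = decoder.log_prob(x, z) - gaussian_kl_std_normal(mu_i, logvar_i)
        loss = -elbo.mean()                          # ascend the ELBO by descending its negative
        opt.zero_grad()
        loss.backward()
        opt.step()
    return lam[0].detach(), lam[1].detach()
```

In fully semi-amortized training, gradients are also propagated through the T inner steps so that ϕ learns initializations that refine well; the detached version above is the simpler variant that only refines the posterior at evaluation time.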
Benefits: It directly shrinks the amortization gap, typically yielding tighter per-instance ELBOs, better reconstructions, and better-trained generative parameters θ, while the amortized network still provides a warm start so only a few refinement steps are needed.

Costs: Each datapoint requires T extra inner optimization steps (at training time, and at test time if refined posteriors are needed there), which increases compute and inference latency; and if you backpropagate through the refinement to update ϕ, the implementation becomes more involved and memory use grows with T.
This approach is particularly useful when the true posterior has high variance across different datapoints, making it difficult for a single amortized network to perform well universally. The number of refinement steps T is a hyperparameter; even a small number of steps (e.g., T=5 to 10) can often yield substantial improvements.
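One practical way to judge whether refinement is worth it, and to choose T, is to estimate the amortization gap on a held-out batch by comparing the ELBO under the amortized parameters with the ELBO after T refinement steps. The sketch below reuses the hypothetical encoder, decoder, refine_posterior, and gaussian_kl_std_normal pieces from the earlier snippets.

```python
def elbo_at(x, mu, logvar, decoder, n_samples=10):
    """Monte Carlo estimate of the ELBO at fixed variational parameters (mu, logvar)."""
    vals = []
    for _ in range(n_samples):
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()
        vals.append(decoder.log_prob(x, z) - gaussian_kl_std_normal(mu, logvar))
    return torch.stack(vals).mean()

# ELBO under the purely amortized parameters
with torch.no_grad():
    mu0, logvar0 = encoder(x_batch)
    elbo_amortized = elbo_at(x_batch, mu0, logvar0, decoder)

# ELBO after T refinement steps (refinement itself needs gradients w.r.t. lambda)
mu_T, logvar_T = refine_posterior(x_batch, encoder, decoder, T=10)
with torch.no_grad():
    elbo_refined = elbo_at(x_batch, mu_T, logvar_T, decoder)

print(f"Approximate amortization gap on this batch: {(elbo_refined - elbo_amortized).item():.4f}")
```

A large positive difference suggests that refinement (or more refinement steps) is likely to pay off; a negligible one suggests the amortized network is already close to per-instance optimal.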
Auxiliary variables and semi-amortized inference are not mutually exclusive. One could, for instance, define an expressive posterior family using auxiliary variables and then use semi-amortized inference to fine-tune the parameters of this richer posterior for each data point.
When deciding whether to use these advanced inference techniques, consider the following: Is the posterior approximation actually the bottleneck, for example, is there a measurable amortization gap, or clear structure (such as multimodality) that a diagonal Gaussian cannot capture? Can you afford the additional parameters and sampling of auxiliary variables, or the per-instance optimization steps of semi-amortized inference? And does your application tolerate slower inference at test time?
In practice, starting with a well-tuned standard VAE and then exploring these techniques can be a good strategy if further improvements in posterior approximation are needed. The choice depends on the specific application, available computational resources, and the desired trade-off between model performance and inference speed. Both methods offer valuable tools for pushing the boundaries of what VAEs can achieve by enabling more accurate and flexible posterior inference.