Recall from the chapter introduction that Variational Inference (VI) transforms the challenge of computing the posterior $p(z \mid x)$ into an optimization problem. We aim to find a distribution $q(z)$ within a chosen family $\mathcal{Q}$ that best approximates the true posterior, typically by maximizing the Evidence Lower Bound (ELBO):
$$
\mathcal{L}(q) = \mathbb{E}_{q(z)}[\log p(x, z)] - \mathbb{E}_{q(z)}[\log q(z)]
$$

The crucial step is selecting the family $\mathcal{Q}$. If we allowed $\mathcal{Q}$ to contain all possible distributions, the optimal $q^*(z)$ would be the true posterior $p(z \mid x)$ itself, bringing us back to the original intractable problem. Therefore, the essence of practical VI lies in choosing a restricted, tractable family for $q(z)$.
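Before restricting $q$, it helps to see that the ELBO is something we can actually evaluate. The following minimal sketch estimates it by Monte Carlo for a hypothetical one-dimensional conjugate model whose evidence is known in closed form; the model, the names `log_joint` and `q`, and the parameter values are illustrative choices, not part of any fixed API.

```python
import numpy as np
from scipy import stats

# Hypothetical model: z ~ N(0, 1),  x | z ~ N(z, 1), with observed x = 2.0.
# Because the model is conjugate, the exact posterior and the evidence
# log p(x) are known, so we can sanity-check the Monte Carlo ELBO.
x_obs = 2.0

def log_joint(z):
    """log p(x, z) = log p(z) + log p(x | z)."""
    return stats.norm.logpdf(z, 0.0, 1.0) + stats.norm.logpdf(x_obs, z, 1.0)

# Variational distribution q(z): here deliberately set to the exact
# posterior N(x/2, 1/2), so the ELBO should equal log p(x) = log N(x; 0, 2).
q = stats.norm(loc=x_obs / 2.0, scale=np.sqrt(0.5))

rng = np.random.default_rng(0)
z_samples = q.rvs(size=100_000, random_state=rng)

# ELBO = E_q[log p(x, z)] - E_q[log q(z)], estimated by sampling from q.
elbo = np.mean(log_joint(z_samples) - q.logpdf(z_samples))

log_evidence = stats.norm.logpdf(x_obs, 0.0, np.sqrt(2.0))
print(f"MC ELBO ≈ {elbo:.4f},  log p(x) = {log_evidence:.4f}")
```

Because this $q$ is the exact posterior, the estimate matches $\log p(x)$ up to Monte Carlo noise; for any other choice of $q$ the ELBO would fall below it.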
One of the most widely used strategies is the mean-field variational family. This approach introduces a significant simplification by assuming that the latent variables $z = (z_1, z_2, \ldots, z_M)$ in the approximating distribution $q(z)$ are mutually independent. Mathematically, we enforce a fully factorized structure:
$$
q(z) = \prod_{j=1}^{M} q_j(z_j; \lambda_j)
$$

Here, the joint variational distribution $q(z)$ is broken down into a product of $M$ independent factors, where each factor $q_j(z_j; \lambda_j)$ governs only a single latent variable $z_j$ (or sometimes a block of variables, though full factorization is common). Each factor $q_j$ is itself a probability distribution, often belonging to a simple parametric family (like a Gaussian or Dirichlet) and parameterized by its own set of variational parameters $\lambda_j$. The goal of VI then becomes finding the optimal parameters $\{\lambda_j\}_{j=1}^{M}$ that maximize the ELBO.
This factorization assumption has a direct impact on the ELBO terms. The second term, $\mathbb{E}_{q(z)}[\log q(z)]$ (the negative entropy of $q$), becomes particularly manageable:
$$
\mathbb{E}_{q(z)}[\log q(z)] = \mathbb{E}_{q(z)}\left[\sum_{j=1}^{M} \log q_j(z_j)\right] = \sum_{j=1}^{M} \mathbb{E}_{q_j(z_j)}[\log q_j(z_j)]
$$

The expectation of the sum becomes the sum of expectations, and because each $q_j$ only depends on $z_j$, the expectation $\mathbb{E}_{q(z)}$ reduces to $\mathbb{E}_{q_j(z_j)}$ for the $j$-th term. The entropy of $q$ is therefore simply the sum of the entropies of the individual factors. If the factors $q_j$ belong to standard exponential families, their entropies often have closed-form expressions, making this term easy to compute.
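As a quick numerical check of this decomposition, the sketch below builds a small factorized $q$ from off-the-shelf SciPy distributions (the particular factors and parameters are arbitrary, chosen only for illustration) and compares the closed-form sum of factor entropies with a Monte Carlo estimate of $-\mathbb{E}_{q(z)}[\log q(z)]$.

```python
import numpy as np
from scipy import stats

# Illustrative mean-field family: three independent factors q_1, q_2, q_3.
factors = [
    stats.norm(loc=0.0, scale=1.5),    # q_1(z_1)
    stats.norm(loc=-2.0, scale=0.3),   # q_2(z_2)
    stats.gamma(a=3.0, scale=0.5),     # q_3(z_3)
]

# Closed form: the entropy of the factorized q is the sum of factor entropies.
entropy_closed_form = sum(f.entropy() for f in factors)

# Monte Carlo check: -E_q[log q(z)], drawing each z_j independently from q_j.
rng = np.random.default_rng(0)
samples = [f.rvs(size=200_000, random_state=rng) for f in factors]
log_q = sum(f.logpdf(s) for f, s in zip(factors, samples))
entropy_mc = -np.mean(log_q)

print(f"closed form: {entropy_closed_form:.4f},  Monte Carlo: {entropy_mc:.4f}")
```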
The first term, $\mathbb{E}_{q(z)}[\log p(x, z)]$, also simplifies conceptually. The expectation is taken over the factorized distribution $q(z)$, meaning we integrate or sum over each $z_j$ according to its factor $q_j(z_j)$. The specific calculation depends heavily on the structure of the model's joint probability $p(x, z)$.
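For instance, if the log joint decomposes into local terms, $\log p(x, z) = \sum_{c} \log \psi_c(x, z_c)$, where each $z_c$ is a small subset of the latent variables (as is typical in graphical models), then under the factorized $q$ each expectation $\mathbb{E}_{q}[\log \psi_c(x, z_c)]$ involves only the few factors $q_j$ with $z_j \in z_c$, so the whole term reduces to a sum of low-dimensional expectations.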
How do we find the optimal form for each factor $q_j(z_j)$? We can optimize the ELBO with respect to one factor $q_j$ while holding the others ($q_k$ for $k \neq j$) fixed. This is the core idea behind algorithms like Coordinate Ascent Variational Inference (CAVI), which we'll discuss in the next section.
Let's isolate the terms in the ELBO that depend on a specific factor $q_j(z_j)$. We can rewrite the ELBO as:
$$
\mathcal{L}(q) = \int q_j(z_j) \left( \int \prod_{k \neq j} q_k(z_k) \log p(x, z) \, dz_{-j} \right) dz_j - \int q_j(z_j) \log q_j(z_j) \, dz_j + (\text{terms not depending on } q_j)
$$

where $dz_{-j}$ denotes integration over all variables $z$ except $z_j$. The term inside the parentheses in the first integral is an expectation of the log joint probability taken with respect to all factors except $q_j$:
$$
\mathbb{E}_{q_{-j}}[\log p(x, z)] \triangleq \int \log p(x, z) \prod_{k \neq j} q_k(z_k) \, dz_{-j}
$$

This expectation $\mathbb{E}_{q_{-j}}[\log p(x, z)]$ yields a function that depends on $z_j$ (and on $x$ and the parameters of the fixed $q_k$'s). Let's denote $\log \tilde{p}_j(z_j) = \mathbb{E}_{q_{-j}}[\log p(x, z)]$. Then the terms in the ELBO depending on $q_j(z_j)$ are:
$$
\mathcal{L}_j = \int q_j(z_j) \log \tilde{p}_j(z_j) \, dz_j - \int q_j(z_j) \log q_j(z_j) \, dz_j
$$

This expression looks familiar. It is equal to the negative Kullback-Leibler (KL) divergence between $q_j(z_j)$ and the normalized distribution proportional to $\tilde{p}_j(z_j)$, plus a constant (the log normalizer of $\tilde{p}_j(z_j)$):
$$
\mathcal{L}_j = -\mathrm{KL}\!\left( q_j(z_j) \,\Big\|\, \frac{\tilde{p}_j(z_j)}{C_j} \right) + \log C_j, \qquad \text{where } C_j = \int \tilde{p}_j(z_j) \, dz_j
$$

Maximizing $\mathcal{L}_j$ (and thus the full ELBO with respect to $q_j$, holding the others fixed) is equivalent to minimizing the KL divergence $\mathrm{KL}\big( q_j(z_j) \,\|\, \tilde{p}_j(z_j)/C_j \big)$. The minimum KL divergence value of zero is achieved when $q_j(z_j)$ is exactly equal to the normalized distribution $\tilde{p}_j(z_j)/C_j$. Therefore, the optimal solution for the factor $q_j^*(z_j)$ satisfies:
$$
q_j^*(z_j) \propto \exp\!\left( \mathbb{E}_{q_{-j}}[\log p(x, z)] \right)
$$

Or, equivalently:
$$
\log q_j^*(z_j) = \mathbb{E}_{q_{-j}}[\log p(x, z)] + \text{constant}
$$

This important result provides a recipe for finding the optimal form of each factor $q_j^*(z_j)$, assuming all other factors $q_k(z_k)$, $k \neq j$, are fixed. It states that the optimal log-density for $z_j$ is obtained by taking the log of the model's joint probability $p(x, z)$ and averaging it over all other variables $z_k$, $k \neq j$, according to their current variational distributions $q_k(z_k)$. This forms the basis for iterative update schemes like CAVI, where we cycle through the variables, updating each $q_j$ based on the current estimates of the others.
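As a concrete illustration of this recipe (a standard textbook example rather than anything specific to a particular model), suppose the target distribution over $z = (z_1, z_2)$ is a correlated bivariate Gaussian, so that $\log p(x, z) = -\tfrac{1}{2}(z - m)^\top \Lambda (z - m) + \text{const}$, where the mean $m$ and precision matrix $\Lambda$ may depend on $x$. Taking the expectation over $q_2$ and keeping only the terms involving $z_1$ gives:

$$
\log q_1^*(z_1) = -\tfrac{1}{2} \Lambda_{11} z_1^2 + z_1 \big( \Lambda_{11} m_1 - \Lambda_{12} (\mathbb{E}_{q_2}[z_2] - m_2) \big) + \text{const}
$$

This is the log-density of a Gaussian with precision $\Lambda_{11}$ and mean $m_1 - \Lambda_{11}^{-1} \Lambda_{12} (\mathbb{E}_{q_2}[z_2] - m_2)$; the update for $q_2^*(z_2)$ is symmetric. Notice that we never assumed $q_1$ was Gaussian: the Gaussian form emerges from the expectation itself.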
The primary strength of the mean-field approximation is computational tractability. By breaking the dependency structure among latent variables within the variational approximation $q$, we convert a difficult optimization over one high-dimensional distribution into a sequence of simpler optimizations over the lower-dimensional factors $q_j$. If the model structure and the choice of $q_j$ families are compatible (e.g., using conjugate priors in the model often leads to recognizable forms for $q_j^*$), these updates can sometimes be derived analytically.
However, this simplification comes at a cost. The central assumption is that the variational posterior factors are independent: $q(z) = \prod_j q_j(z_j)$. If the true posterior $p(z \mid x)$ exhibits significant correlations between the latent variables $z_j$, the mean-field approximation $q(z)$ will fail to capture these dependencies by definition.
Consider a simple 2D Gaussian example. If the true posterior shows strong negative correlation between $z_1$ and $z_2$, the best mean-field approximation $q(z_1, z_2) = q_1(z_1) q_2(z_2)$ will be a Gaussian with zero correlation (its contours will be aligned with the axes), even if it manages to center itself correctly and approximate the marginal variances.
Illustration comparing a true posterior $p(z_1, z_2 \mid x)$ with correlation (dashed orange contours) and its best mean-field approximation $q(z_1)q(z_2)$ (solid blue contours). The approximation centers correctly but enforces independence, thus missing the correlation structure present in the true posterior.
This inability to capture posterior correlations is a well-known characteristic of mean-field VI. Because the KL divergence $\mathrm{KL}(q \,\|\, p)$ penalizes placing probability mass with $q$ where $p$ has none, the resulting $q(z)$ tends to be more "compact" or concentrated around the mode than the true posterior $p(z \mid x)$. This often leads to an underestimation of posterior variances and potentially overly confident uncertainty estimates.
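The sketch below, assuming the same hypothetical bivariate Gaussian setup as the worked example above (parameters chosen purely for illustration), iterates the two coordinate updates to convergence. The factor means recover the true mean, the factorized $q$ has zero correlation by construction, and each factor variance comes out as $1/\Lambda_{jj} = 0.19$, well below the true marginal variance of $1.0$.

```python
import numpy as np

# Hypothetical "true posterior": a bivariate Gaussian N(m, Sigma) with
# strong negative correlation. We iterate the mean-field coordinate
# updates derived above and compare the factor variances with the true
# marginal variances.
m = np.array([1.0, -1.0])                    # true posterior mean
Sigma = np.array([[1.0, -0.9],
                  [-0.9, 1.0]])              # true posterior covariance
Lambda = np.linalg.inv(Sigma)                # precision matrix

# Factor means start arbitrarily; the factor precisions are fixed by the
# update rule to Lambda_jj.
mu = np.array([0.0, 0.0])

for _ in range(50):
    # q_1 update: mean m_1 - Lambda_12 / Lambda_11 * (E[z_2] - m_2)
    mu[0] = m[0] - Lambda[0, 1] / Lambda[0, 0] * (mu[1] - m[1])
    # q_2 update: mean m_2 - Lambda_21 / Lambda_22 * (E[z_1] - m_1)
    mu[1] = m[1] - Lambda[1, 0] / Lambda[1, 1] * (mu[0] - m[0])

var_mf = 1.0 / np.diag(Lambda)               # mean-field factor variances
print("factor means:      ", mu)             # converge to the true mean m
print("factor variances:  ", var_mf)         # 1/Lambda_jj = 0.19
print("true marginal var: ", np.diag(Sigma)) # 1.0 -> underestimated
```

In this example the shrinkage factor is exactly $1 - \rho^2 = 0.19$ for correlation $\rho = -0.9$: the stronger the posterior correlation, the more severely the mean-field factors understate the marginal variances.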
Despite these limitations, the mean-field assumption is foundational in variational inference. Its computational advantages are significant, making Bayesian inference feasible for many complex models and large datasets where MCMC methods might struggle with convergence time or memory requirements. Understanding the mean-field approximation, its derivation via the optimal factor updates, and its inherent limitations regarding posterior dependencies and variance estimation is essential for applying VI effectively and interpreting its results critically. We will now explore the Coordinate Ascent Variational Inference (CAVI) algorithm, which directly implements the iterative updates derived from the mean-field assumption.