Okay, let's dive into Coordinate Ascent Variational Inference (CAVI). We've established that Variational Inference (VI) aims to find the best approximation $q(z)$ to the true posterior $p(z \mid x)$ within a chosen family $\mathcal{Q}$, typically by maximizing the Evidence Lower Bound (ELBO). The common choice, the mean-field variational family, assumes independence between latent variables (or groups of variables) in the approximation:
$$q(z) = \prod_{j=1}^{M} q_j(z_j)$$
where $z = \{z_1, \dots, z_M\}$ are the latent variables. The question now is: how do we find the optimal factors $q_j(z_j)$ that maximize the ELBO, $\mathcal{L}(q)$?
CAVI provides an iterative approach. It maximizes the ELBO with respect to each factor $q_j(z_j)$ one at a time, holding all other factors $q_{i \neq j}(z_i)$ fixed. This is coordinate ascent, applied in the space of distributions rather than over a finite-dimensional parameter vector.
Deriving the CAVI Update Rule
Let's focus on optimizing the ELBO with respect to a single factor, say $q_j(z_j)$. Recall the ELBO:
$$\mathcal{L}(q) = \mathbb{E}_q[\log p(x, z)] - \mathbb{E}_q[\log q(z)]$$
Substituting the mean-field assumption $q(z) = \prod_{i=1}^{M} q_i(z_i)$:
$$\mathcal{L}(q) = \int \cdots \int \left( \prod_{i=1}^{M} q_i(z_i) \right) \left( \log p(x, z) - \sum_{i=1}^{M} \log q_i(z_i) \right) dz_1 \cdots dz_M$$
We want to find the $q_j(z_j)$ that maximizes this expression, holding the other factors $q_{i \neq j}$ fixed. Let's isolate the terms involving $q_j(z_j)$:
$$\mathcal{L}(q) = \int q_j(z_j) \left( \int \cdots \int \left( \prod_{i \neq j} q_i(z_i) \right) \log p(x, z) \prod_{i \neq j} dz_i \right) dz_j - \int q_j(z_j) \log q_j(z_j)\, dz_j + \text{const}$$
Here, "const" represents terms that do not depend on qj(zj). The term inside the large parentheses in the first integral is the expectation of the log joint distribution with respect to all factors except qj. Let's denote this as Ei=j[logp(x,z)]. So we have:
$$\mathcal{L}(q) = \int q_j(z_j)\, \mathbb{E}_{i \neq j}[\log p(x, z)]\, dz_j - \int q_j(z_j) \log q_j(z_j)\, dz_j + \text{const}$$
Now define $\log \tilde{p}_j(z_j) = \mathbb{E}_{i \neq j}[\log p(x, z)] + \text{const}$, with the additive constant chosen so that $\tilde{p}_j(z_j)$ is a normalized density. The expression becomes:
$$\mathcal{L}(q) = \int q_j(z_j) \log \tilde{p}_j(z_j)\, dz_j - \int q_j(z_j) \log q_j(z_j)\, dz_j + \text{const}$$
$$\mathcal{L}(q) = -\int q_j(z_j) \log \frac{q_j(z_j)}{\tilde{p}_j(z_j)}\, dz_j + \text{const}$$
$$\mathcal{L}(q) = -\mathrm{KL}\big(q_j(z_j) \,\|\, \tilde{p}_j(z_j)\big) + \text{const}$$
Maximizing $\mathcal{L}(q)$ with respect to $q_j(z_j)$ is therefore equivalent to minimizing the Kullback-Leibler (KL) divergence between $q_j(z_j)$ and $\tilde{p}_j(z_j)$. The KL divergence is minimized (equals zero) exactly when $q_j(z_j) = \tilde{p}_j(z_j)$. Therefore, the optimal factor $q_j^*(z_j)$ satisfies:
$$\log q_j^*(z_j) = \mathbb{E}_{i \neq j}[\log p(x, z)] + \text{const}$$
Exponentiating both sides gives:
$$q_j^*(z_j) \propto \exp\left( \mathbb{E}_{i \neq j}[\log p(x, z)] \right)$$
The constant of proportionality is determined by the normalization requirement $\int q_j^*(z_j)\, dz_j = 1$.
This is the core update rule for CAVI. It states that the optimal distribution for the $j$-th factor, $q_j^*(z_j)$, is proportional to the exponentiated expected log of the joint probability $p(x, z)$, where the expectation is taken with respect to the current distributions of all other factors $q_{i \neq j}(z_i)$.
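A useful equivalent form of this rule follows from the decomposition $\log p(x, z) = \log p(z_j \mid x, z_{\neg j}) + \log p(x, z_{\neg j})$, where $z_{\neg j}$ denotes all latent variables except $z_j$; the second term does not involve $z_j$ and is absorbed into the normalization, giving

$$q_j^*(z_j) \propto \exp\left( \mathbb{E}_{i \neq j}\big[\log p(z_j \mid x, z_{\neg j})\big] \right)$$

That is, each CAVI update exponentiates the expected log of the complete conditional of $z_j$, a form that will be convenient when we discuss conjugacy below.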
The CAVI Algorithm
The CAVI algorithm proceeds as follows:
- Initialization: Initialize the parameters of each variational factor $q_j(z_j)$ for $j = 1, \dots, M$. This might involve initializing means, variances, or other relevant parameters, depending on the distributional form chosen for each $q_j$.
- Iteration: Repeat until convergence:
- For each factor $j = 1, \dots, M$:
- Compute the expectation $\mathbb{E}_{i \neq j}[\log p(x, z)]$. This requires the current parameters of the other factors $q_{i \neq j}(z_i)$.
- Update the factor qj(zj) using the rule:
$$q_j(z_j) \leftarrow \frac{\exp\left( \mathbb{E}_{i \neq j}[\log p(x, z)] \right)}{\int \exp\left( \mathbb{E}_{i \neq j}[\log p(x, z)] \right) dz_j}$$
In practice, this often means updating the parameters of the distribution $q_j(z_j)$ based on the computed expectation.
- Convergence Check: Monitor the ELBO or the parameters of the variational factors. Stop when the change between iterations falls below a predefined threshold.
The CAVI algorithm thus cycles through the factors, updating each $q_j(z_j)$ using the expected log joint probability under the current state of the other factors, until convergence; a concrete sketch follows below.
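To make the loop concrete, here is a minimal sketch in Python (with NumPy) for a toy model: observations $x_1, \dots, x_N \sim \mathcal{N}(\mu, \tau^{-1})$ with independent priors $\mu \sim \mathcal{N}(\mu_0, \lambda_0^{-1})$ and $\tau \sim \mathrm{Gamma}(a_0, b_0)$, approximated by $q(\mu)\,q(\tau)$ with $q(\mu)$ Gaussian and $q(\tau)$ a Gamma distribution. The model choice, parameter names, and the function `cavi_normal_gamma` are illustrative assumptions, not something fixed by the text above; the update formulas come from applying the CAVI rule to this particular joint.

```python
import numpy as np

def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0,
                      max_iter=100, tol=1e-8):
    """CAVI sketch for x_i ~ N(mu, 1/tau) with priors mu ~ N(mu0, 1/lam0)
    and tau ~ Gamma(a0, b0), using the mean-field family
    q(mu, tau) = q(mu) q(tau), where q(mu) = N(mu_n, 1/lam_n) and
    q(tau) = Gamma(a_n, b_n)."""
    x = np.asarray(x, dtype=float)
    n, x_bar = x.size, x.mean()

    # Initialize the variational parameters (any reasonable values work).
    mu_n, lam_n = x_bar, 1.0
    a_n, b_n = a0, b0

    for _ in range(max_iter):
        mu_old = mu_n

        # Update q(mu): exp(E_{q(tau)}[log p(x, mu, tau)]) is Gaussian in mu;
        # completing the square requires only E[tau] = a_n / b_n.
        e_tau = a_n / b_n
        lam_n = lam0 + n * e_tau
        mu_n = (lam0 * mu0 + e_tau * n * x_bar) / lam_n

        # Update q(tau): exp(E_{q(mu)}[log p(x, mu, tau)]) is a Gamma density;
        # it needs E[(x_i - mu)^2] = (x_i - mu_n)^2 + 1/lam_n under q(mu).
        a_n = a0 + 0.5 * n
        b_n = b0 + 0.5 * (np.sum((x - mu_n) ** 2) + n / lam_n)

        # Convergence check on the variational mean of mu (the ELBO could be
        # monitored instead, as described above).
        if abs(mu_n - mu_old) < tol:
            break

    return mu_n, lam_n, a_n, b_n
```

For example, `cavi_normal_gamma(np.random.normal(2.0, 0.5, size=500))` should return a variational mean for $\mu$ close to 2 and an expected precision $a_N / b_N$ close to $1/0.5^2 = 4$. Note that each of the two updates is just the generic rule $q_j^*(z_j) \propto \exp(\mathbb{E}_{i \neq j}[\log p(x, z)])$ evaluated in closed form; conjugacy, discussed next, is what makes that closed form available.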
The Role of Conjugacy
The practicality of CAVI often hinges on the structure of the model $p(x, z)$. If the complete conditional $p(z_j \mid x, z_{\neg j})$ (where $z_{\neg j}$ denotes all latent variables except $z_j$) belongs to the same exponential family as the prior on $z_j$, the model exhibits conditional conjugacy. In such cases, the CAVI update for $q_j^*(z_j)$ results in a distribution of the same parametric form as the one chosen for $q_j(z_j)$.
For instance, if $q_j(z_j)$ is chosen to be Gaussian, and the model structure allows, the update $q_j^*(z_j)$ derived from the expectation $\mathbb{E}_{i \neq j}[\log p(x, z)]$ might also be Gaussian. The update step then simplifies to calculating the new mean and variance parameters from the expected values (moments) of the other factors $q_{i \neq j}$. This makes the implementation significantly easier, as we only need to track and update the parameters (e.g., means and variances) of the factors.
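As a concrete illustration, take the toy Normal-Gamma model assumed in the code sketch above (again an illustrative choice, not one specified earlier in the text). Both complete conditionals are conjugate, so each CAVI update merely moves the parameters of a Gaussian factor and a Gamma factor:

$$q^*(\mu) = \mathcal{N}\!\left(\mu \;\middle|\; \frac{\lambda_0 \mu_0 + \mathbb{E}[\tau]\, N \bar{x}}{\lambda_0 + N\,\mathbb{E}[\tau]},\; \big(\lambda_0 + N\,\mathbb{E}[\tau]\big)^{-1}\right), \qquad \mathbb{E}[\tau] = \frac{a_N}{b_N}$$

$$q^*(\tau) = \mathrm{Gamma}\!\left(\tau \;\middle|\; a_0 + \frac{N}{2},\; b_0 + \frac{1}{2}\left(\sum_{i=1}^{N}(x_i - \mu_N)^2 + \frac{N}{\lambda_N}\right)\right)$$

Here $\mu_N$ and $\lambda_N^{-1}$ are the current mean and variance of $q(\mu)$, and $\bar{x}$ is the sample mean. Each update consumes only low-order moments of the other factor ($\mathbb{E}[\tau]$ for the $\mu$-update; $\mathbb{E}[\mu]$ and $\mathrm{Var}[\mu]$ for the $\tau$-update), which is exactly the simplification conditional conjugacy provides.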
Advantages and Disadvantages of CAVI
Advantages:
- Guaranteed Convergence: The ELBO is guaranteed to increase (or stay the same) at each step of CAVI, ensuring convergence to a local optimum.
- Deterministic: Unlike MCMC, CAVI is a deterministic optimization algorithm, yielding the same result given the same initialization.
- Simplicity (with Conjugacy): When conditional conjugacy applies, the updates can often be derived analytically and involve updating parameters of standard distributions.
Disadvantages:
- Local Optima: CAVI optimizes the ELBO, which is generally non-convex. The algorithm may converge to a local optimum, potentially providing a poor approximation of the true posterior. Different initializations might lead to different results.
- Mean-Field Limitation: The accuracy is fundamentally limited by the expressiveness of the chosen variational family. The mean-field assumption (independence between factors) might be too restrictive for models with strong posterior dependencies.
- Analytical Calculations: Deriving the update rule requires computing $\mathbb{E}_{i \neq j}[\log p(x, z)]$, which can be complex or analytically intractable for some models.
- Scalability: Traditional CAVI requires iterating through the entire dataset to compute expectations for global variable updates (parameters governing all data points), which doesn't scale well to very large datasets.
CAVI provides a foundational understanding of variational optimization. While effective for moderately sized problems with suitable model structures, its limitations, particularly regarding scalability and the need for analytical derivations, motivate the development of more advanced techniques like Stochastic Variational Inference (SVI) and Black Box Variational Inference (BBVI), which we will explore next.