Okay, let's put the theory of scalable Variational Inference into practice. In the previous sections, we established that calculating the exact posterior p(z∣x) is often intractable, and while MCMC offers an alternative, it can be slow for large datasets. Variational Inference reformulates inference as an optimization problem, seeking an approximate distribution q(z) that minimizes the KL divergence to the true posterior, equivalent to maximizing the Evidence Lower Bound (ELBO):
$$\mathcal{L}(q) = \mathbb{E}_{q(z)}\left[\log p(x, z)\right] - \mathbb{E}_{q(z)}\left[\log q(z)\right]$$

For large datasets, calculating the expectation $\mathbb{E}_{q(z)}[\log p(x, z)]$ over the entire dataset $x$ at each optimization step is computationally prohibitive. Stochastic Variational Inference (SVI) addresses this by using mini-batches of data to estimate the gradients of the ELBO, allowing us to use stochastic optimization techniques like Adam. Black Box Variational Inference (BBVI) further extends this by providing a general way to compute gradients even when the ELBO gradient isn't analytically tractable, relying on the score function estimator.
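When the likelihood factorizes over data points, i.e. $\log p(x, z) = \log p(z) + \sum_{i=1}^{N} \log p(x_i \mid z)$, a mini-batch $B$ of size $M$ gives an unbiased estimate of the ELBO (and of its gradient):

$$\hat{\mathcal{L}}(q) = \frac{N}{M} \sum_{i \in B} \mathbb{E}_{q(z)}\left[\log p(x_i \mid z)\right] + \mathbb{E}_{q(z)}\left[\log p(z)\right] - \mathbb{E}_{q(z)}\left[\log q(z)\right]$$

Pyro applies the $N/M$ rescaling automatically when the data plate is told the full dataset size; we return to this detail after defining the model below.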
In this practical exercise, we'll implement SVI for a Bayesian Logistic Regression model using the Pyro probabilistic programming library. This will demonstrate how to handle larger datasets efficiently.
Our goal: implement and train a Bayesian Logistic Regression model using Stochastic Variational Inference (SVI) on a simulated dataset, monitor the ELBO during training, and examine the learned approximate posterior distributions for the model parameters.
First, ensure you have the necessary libraries installed. You'll primarily need torch and pyro-ppl.
# Standard libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
# Pyro libraries
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
# Ensure reproducibility
pyro.set_rng_seed(101)
Let's create a simple binary classification dataset. We'll simulate data where the classification depends linearly on two features, passed through a sigmoid function.
# Generate synthetic data
N = 1000 # Number of data points
D = 2 # Number of features
true_weights = torch.tensor([1.5, -2.0])
true_bias = torch.tensor([0.5])
# Generate features (normally distributed)
X_data = torch.randn(N, D)
# Calculate linear combination and probability
linear_combination = X_data @ true_weights + true_bias
prob = torch.sigmoid(linear_combination)
# Generate binary labels based on the probability
y_data = torch.bernoulli(prob).float() # Ensure labels are float for potential loss calculation needs
print(f"Generated X_data shape: {X_data.shape}")
print(f"Generated y_data shape: {y_data.shape}")
In Pyro, we define our probabilistic model as a Python function that takes the data as input. Inside this function, we use pyro.sample statements to declare random variables (parameters with priors, and the observed likelihood).
def logistic_regression_model(features, labels):
    # Priors for weights and bias: standard Normal (mean=0, std=1)
    # to_event(1) treats the D weight dimensions as a single D-dimensional event
    weights = pyro.sample("weights",
                          dist.Normal(torch.zeros(D), torch.ones(D)).to_event(1))
    # Bias is a single scalar value
    bias = pyro.sample("bias", dist.Normal(0., 1.))
    # Use pyro.plate to indicate conditional independence over data points
    with pyro.plate("data", size=features.shape[0]):
        # The linear combination gives the logits (log-odds) for each data point
        logits = features @ weights + bias
        # Likelihood: Bernoulli parameterized by logits (the sigmoid is applied internally)
        # obs=labels connects the distribution to the observed data
        pyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
This logistic_regression_model function defines the generative process: sample the weights and bias from their priors, compute the logits for each data point, and sample the observed labels from a Bernoulli distribution (which applies the logistic function to the logits internally).
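In the training loop later in this section we slice mini-batches by hand and pass them to svi.step, so the plate above only ever sees the current batch. The likelihood of roughly 100 points is then balanced against the full prior, which tends to make the learned posterior somewhat wider than the full-data posterior. If you want Pyro to rescale the likelihood for you, one option is a variant of the model (a sketch only, not used in the rest of this exercise) that receives the full dataset and lets pyro.plate draw and rescale the mini-batch:
# Sketch of a variant model: pass the *full* dataset and let pyro.plate draw
# the mini-batch. Pyro then rescales the likelihood by N / subsample_size so
# each step estimates the full-data ELBO.
def logistic_regression_model_subsampled(features, labels, subsample_size=100):
    weights = pyro.sample("weights",
                          dist.Normal(torch.zeros(D), torch.ones(D)).to_event(1))
    bias = pyro.sample("bias", dist.Normal(0., 1.))
    with pyro.plate("data", size=features.shape[0],
                    subsample_size=subsample_size) as ind:
        logits = features[ind] @ weights + bias
        pyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels[ind])
With this variant you would call svi.step(X_data, y_data) with the full arrays and drop the manual index slicing from the training loop; we keep the manual mini-batching below because it makes the stochastic nature of SVI explicit.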
The guide function, q(z), specifies the family of distributions we'll use to approximate the posterior. For mean-field VI, we assume the latent variables (weights and bias in our case) are independent in the approximate posterior. We typically use distributions with learnable parameters (e.g., Normal distributions where mean and standard deviation are parameters to be optimized).
def mean_field_guide(features, labels):
    # Variational parameters for the weights
    # 'weight_loc' and 'weight_scale' will be learned during optimization
    weight_loc = pyro.param("weight_loc", torch.randn(D))
    # The scale must be positive, so we pass the raw parameter through a softplus
    weight_scale = pyro.param("weight_scale", torch.randn(D))
    weight_scale_constrained = torch.nn.functional.softplus(weight_scale)
    # Variational parameters for the bias
    bias_loc = pyro.param("bias_loc", torch.randn(1))
    bias_scale = pyro.param("bias_scale", torch.randn(1))
    bias_scale_constrained = torch.nn.functional.softplus(bias_scale)
    # Sample the latent variables from the guide distributions
    # The site names (and event shapes) must match those used in the model
    pyro.sample("weights",
                dist.Normal(weight_loc, weight_scale_constrained).to_event(1))
    pyro.sample("bias", dist.Normal(bias_loc, bias_scale_constrained))
Here, pyro.param declares learnable parameters. weight_loc, weight_scale, bias_loc, and bias_scale are the parameters of our approximate Normal distributions for the weights and bias, respectively. We use the softplus function to ensure the scale parameters remain positive.
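Writing the guide by hand like this is instructive, but Pyro can also construct an equivalent mean-field Normal guide automatically. A minimal sketch using pyro.infer.autoguide.AutoDiagonalNormal (note that the parameter names it registers differ from the hand-written ones above):
# Alternative: let Pyro build a mean-field (diagonal) Normal guide automatically
from pyro.infer.autoguide import AutoDiagonalNormal

auto_guide = AutoDiagonalNormal(logistic_regression_model)
# auto_guide could be passed to SVI below in place of mean_field_guide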
Now we configure the SVI algorithm. We need an optimizer (Adam is common) and the loss function (Trace_ELBO, which computes a Monte Carlo estimate of the negative ELBO).
# Setup the optimizer
optimizer = Adam({"lr": 0.01}) # Learning rate
# Setup the inference algorithm
svi = SVI(model=logistic_regression_model,
          guide=mean_field_guide,
          optim=optimizer,
          loss=Trace_ELBO())
We run the SVI optimization loop. In each step, svi.step computes an estimate of the ELBO gradient using a mini-batch of data and updates the parameters defined in the guide (weight_loc, weight_scale, etc.) via the optimizer.
# Training configuration
num_iterations = 2000
batch_size = 100
n_data = X_data.shape[0]
elbo_history = []
# Clear Pyro's parameter store before starting training
pyro.clear_param_store()
print("Starting SVI training...")
for j in range(num_iterations):
    # Create mini-batch indices
    indices = torch.randperm(n_data)[:batch_size]
    mini_batch_features = X_data[indices]
    mini_batch_labels = y_data[indices]
    # Calculate the loss and take a gradient step
    loss = svi.step(mini_batch_features, mini_batch_labels)
    elbo_history.append(-loss)  # Store ELBO (negative loss)
    if j % 200 == 0:
        print(f"[Iteration {j+1}/{num_iterations}] ELBO: {-loss:.2f}")
print("SVI training finished.")
After training, the Pyro parameter store contains the optimized values for the guide's parameters. We can access them to understand the approximate posterior.
# Retrieve the learned parameters of the guide
learned_params = {}
for name, param_val in pyro.get_param_store().items():
    learned_params[name] = param_val.detach().numpy()
print("\nLearned Variational Parameters:")
print(f"Weight Mean (loc): {learned_params['weight_loc']}")
print(f"Weight Std Dev (scale): {np.log(1 + np.exp(learned_params['weight_scale']))}") # Apply softplus inverse
print(f"Bias Mean (loc): {learned_params['bias_loc']}")
print(f"Bias Std Dev (scale): {np.log(1 + np.exp(learned_params['bias_scale']))}")
# Compare with true values
print(f"\nTrue Weights: {true_weights.numpy()}")
print(f"True Bias: {true_bias.numpy()}")
You should see that the learned means (weight_loc, bias_loc) are reasonably close to the true weights and bias used to generate the data. The learned standard deviations (weight_scale and bias_scale, after the softplus transformation) give us a measure of uncertainty about these parameters.
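To actually use the approximate posterior for prediction, we can draw parameter samples from the guide and push them through the model. The sketch below uses pyro.infer.Predictive and, purely for illustration, evaluates the posterior predictive on the training data:
from pyro.infer import Predictive

# Draw 500 parameter samples from the guide and simulate labels from the model.
# Passing labels=None leaves the "obs" site unobserved, so it is sampled.
predictive = Predictive(logistic_regression_model, guide=mean_field_guide,
                        num_samples=500, return_sites=("obs",))
posterior_samples = predictive(X_data, None)

# Average the simulated 0/1 labels to get posterior predictive probabilities
pred_prob = posterior_samples["obs"].mean(dim=0)
accuracy = ((pred_prob > 0.5).float() == y_data).float().mean()
print(f"Posterior predictive accuracy (training data): {accuracy:.3f}")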
Let's plot the ELBO values recorded during training. We expect the ELBO to increase and plateau, indicating convergence of the optimization.
# Plot ELBO history
plt.figure(figsize=(10, 4))
plt.plot(elbo_history)
plt.title("ELBO Convergence During SVI Training")
plt.xlabel("Iteration")
plt.ylabel("ELBO")
plt.grid(True)
plt.show()
The Evidence Lower Bound (ELBO) generally increases during training, indicating that the variational distribution is becoming a better approximation of the true posterior. The curve typically flattens out as the optimization converges.
We can also visualize the approximate posterior distributions for the weights compared to their true values.
# Visualize the approximate posterior for weights
from scipy.stats import norm
# Create subplots
fig, axs = plt.subplots(1, D, figsize=(12, 5), sharey=True)
fig.suptitle('Approximate Posterior Distributions for Weights')
for i in range(D):
    # Get the mean and std dev for the current weight
    loc = learned_params['weight_loc'][i]
    scale = np.log(1 + np.exp(learned_params['weight_scale'][i]))  # Apply softplus
    # Generate x values around the mean
    x = np.linspace(loc - 3*scale, loc + 3*scale, 200)
    # Calculate the probability density function (PDF)
    y = norm.pdf(x, loc, scale)
    # Plot the PDF
    axs[i].plot(x, y, label=f'Approx. Posterior q(w_{i})', color='#1c7ed6')
    # Plot the true weight value as a vertical line
    axs[i].axvline(true_weights[i].numpy(), color='#f03e3e', linestyle='--', label=f'True w_{i}')
    axs[i].set_title(f'Weight {i+1}')
    axs[i].set_xlabel('Value')
    axs[i].legend()
    axs[i].grid(True, linestyle=':', alpha=0.6)
axs[0].set_ylabel('Probability Density')
plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
plt.show()
Approximate posterior distributions for the model weights obtained via SVI. The plots show the learned Normal distributions (blue curves) centered near the true parameter values (red dashed lines), with widths indicating the uncertainty captured by the variational approximation.
This hands-on example demonstrates the core workflow of SVI: defining a model and guide, setting up the SVI optimizer and loss, and running the training loop with mini-batches. SVI allows us to apply Bayesian inference to much larger datasets than would be feasible with exact methods or standard MCMC, by leveraging stochastic optimization.
We used a simple mean-field guide here. For more complex posterior geometries, exploring more advanced variational families (covered briefly in Section 3.7) might yield better approximations, often at the cost of increased computational complexity during optimization. The choice between SVI, BBVI, and MCMC methods often depends on the specific model, dataset size, required accuracy, and available computational resources, a trade-off we explored in Section 3.8.
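For instance, a full-rank Gaussian approximation can capture correlations between the weights and the bias that the mean-field guide ignores. A brief sketch using Pyro's autoguides (same model and optimizer settings as above; this guide is not trained in this exercise):
from pyro.infer.autoguide import AutoMultivariateNormal

# A richer variational family: a single joint Gaussian over all latent
# parameters, with a full covariance matrix learned during SVI
full_rank_guide = AutoMultivariateNormal(logistic_regression_model)
svi_full_rank = SVI(logistic_regression_model, full_rank_guide,
                    Adam({"lr": 0.01}), loss=Trace_ELBO())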