Alright, let's roll up our sleeves and get practical. You've grappled with the mathematical underpinnings of Variational Autoencoders, from deriving the Evidence Lower Bound (ELBO) to understanding the reparameterization trick and the role of KL divergence. Now, it's time to translate that theory into a working VAE. This hands-on section will guide you through implementing a VAE from scratch, training it, and then performing essential diagnostics to understand its behavior and common failure modes. Our goal is not just to build a model, but to connect its tangible outputs and training dynamics back to the principles discussed earlier in this chapter.
We'll use Python and a popular deep learning framework like PyTorch for this exercise. The concepts are transferable to other frameworks like TensorFlow with minimal changes in syntax. We'll focus on the MNIST dataset of handwritten digits, a classic choice that allows us to concentrate on the VAE's mechanics without getting bogged down by complex data preprocessing or overly large network architectures.
Before we begin, ensure you have PyTorch installed, along with torchvision for the MNIST dataset and matplotlib for visualizations.
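If you want to confirm your environment before running anything, a quick check like the following can save debugging time later (a minimal sketch; the printed versions will simply be whatever you have installed):
# Quick environment check: library versions and whether a GPU is visible
import torch
import torchvision
import matplotlib

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("matplotlib:", matplotlib.__version__)
print("CUDA available:", torch.cuda.is_available())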
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
We'll define some hyperparameters upfront:
# Hyperparameters
latent_dims = 20 # Dimensionality of the latent space
image_size = 28 * 28 # MNIST images are 28x28
batch_size = 128
learning_rate = 1e-3
num_epochs = 30 # Adjust as needed
And load the MNIST dataset:
# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts to [0, 1] range and C, H, W format
    # We don't normalize for VAEs with Bernoulli output,
    # as pixel values are treated as probabilities.
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
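As a quick sanity check (a small sketch, not part of the training pipeline), you can inspect one batch to confirm the tensor shapes and the [0, 1] pixel range:
# Inspect a single batch: shapes and value range
images, labels = next(iter(train_loader))
print(images.shape)                              # e.g., torch.Size([128, 1, 28, 28])
print(images.min().item(), images.max().item())  # pixel values should lie in [0, 1]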
A VAE consists of two main neural networks: an encoder and a decoder.
High-level data flow in a Variational Autoencoder.
The encoder, parameterized by ϕ, takes an input data point x (an image in our case) and outputs the parameters of the approximate posterior distribution qϕ(z∣x). For a Gaussian posterior, these parameters are the mean μ and the logarithm of the variance logσ2 (log-variance). Using log-variance improves numerical stability and ensures that the variance σ2 is always positive.
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        h = self.relu(self.fc1(x))
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)  # log_var for numerical stability
        return mu, log_var
Here, input_dim is image_size, latent_dim is latent_dims, and hidden_dim can be chosen freely, e.g., 400.
To allow gradients to flow back through the sampling process (sampling z∼qϕ(z∣x)), we use the reparameterization trick. If z∼N(μ,σ2), we can write z=μ+σ⋅ϵ, where ϵ∼N(0,I). The randomness is now externalized to ϵ.
def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)  # std = exp(0.5 * log(var))
    eps = torch.randn_like(std)     # Sample epsilon from N(0, I)
    return mu + eps * std
The decoder, parameterized by θ, takes a latent vector z and reconstructs the data point x. For MNIST, since pixel values are typically between 0 and 1, we can model the output as parameters of a Bernoulli distribution for each pixel. The decoder's final layer will output logits, which are then passed through a sigmoid function to get probabilities.
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        # We'll apply sigmoid in the loss function or for generation,
        # as nn.BCEWithLogitsLoss is more stable.

    def forward(self, z):
        h = self.relu(self.fc1(z))
        x_reconstructed_logits = self.fc2(h)
        return x_reconstructed_logits
Here, output_dim is image_size.
Now, we combine the encoder and decoder into a single VAE model.
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)  # output_dim equals input_dim

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = reparameterize(mu, log_var)
        x_reconstructed_logits = self.decoder(z)
        return x_reconstructed_logits, mu, log_var

# Initialize the model
model = VAE(input_dim=image_size, hidden_dim=400, latent_dim=latent_dims)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
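Before training, it can be useful to confirm the model size. This small sketch simply counts the trainable parameters:
# Count trainable parameters as a quick architecture sanity check
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")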
The VAE is trained by maximizing the ELBO, which is equivalent to minimizing the negative ELBO. As you recall, the ELBO consists of two terms: the expected reconstruction log-likelihood, which rewards faithful reconstructions, and the KL divergence between the approximate posterior qϕ(z∣x) and the prior p(z), which regularizes the latent space. The loss function to minimize is:
$$\mathcal{L}(x, \hat{x}, \mu, \log\sigma^2) = \text{Reconstruction Loss} + \text{KL Divergence}$$

For Gaussian $q_\phi(z \mid x) = \mathcal{N}(z \mid \mu, \operatorname{diag}(\sigma^2))$ and $p(z) = \mathcal{N}(z \mid 0, I)$, the KL divergence has a convenient analytical form:

$$D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big) = -\frac{1}{2}\sum_{j=1}^{D}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

where $D$ is the dimensionality of the latent space.
def loss_function(x_reconstructed_logits, x, mu, log_var):
    # Reconstruction loss (using binary_cross_entropy_with_logits for numerical stability).
    # It expects logits as input and raw pixel values in [0, 1] as targets.
    # reduction='sum' sums over all pixels and all batch elements.
    recon_loss = nn.functional.binary_cross_entropy_with_logits(x_reconstructed_logits, x, reduction='sum')

    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
    # summed over latent dimensions and batch elements.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return recon_loss + kld  # This is the negative ELBO (summed over the batch) to be minimized.
Note: binary_cross_entropy_with_logits averages over all elements when reduction='mean' (the default) and sums when reduction='sum'. When summing, it's common to divide by batch_size afterwards to keep loss magnitudes consistent across batches. Here, torch.sum in the KL term sums over both the latent dimensions and the batch items, and the two summed losses are added to give the total loss for the batch.
The training loop involves fetching a batch of data, passing it through the VAE, calculating the loss, and updating the model parameters using an optimizer like Adam.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Lists to store loss components for plotting
train_losses = []
recon_losses = []
kld_losses = []

model.train()  # Set model to training mode
for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_recon_loss = 0.0
    epoch_kld_loss = 0.0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, image_size).to(device)  # Flatten images to vectors

        # Forward pass
        x_reconstructed_logits, mu, log_var = model(data)

        # Compute loss (summed over the batch)
        loss = loss_function(x_reconstructed_logits, data, mu, log_var)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate summed losses; we normalize by the dataset size below
        # to get a per-sample (Monte Carlo) estimate of the negative ELBO.
        epoch_loss += loss.item()
        with torch.no_grad():
            batch_recon_loss = nn.functional.binary_cross_entropy_with_logits(
                x_reconstructed_logits, data, reduction='sum')
            batch_kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        epoch_recon_loss += batch_recon_loss.item()
        epoch_kld_loss += batch_kld.item()

    # Calculate average (per-sample) losses for the epoch
    avg_epoch_loss = epoch_loss / len(train_loader.dataset)
    avg_epoch_recon_loss = epoch_recon_loss / len(train_loader.dataset)
    avg_epoch_kld_loss = epoch_kld_loss / len(train_loader.dataset)

    train_losses.append(avg_epoch_loss)
    recon_losses.append(avg_epoch_recon_loss)
    kld_losses.append(avg_epoch_kld_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Avg Loss: {avg_epoch_loss:.4f}, '
          f'Avg Recon Loss: {avg_epoch_recon_loss:.4f}, '
          f'Avg KLD: {avg_epoch_kld_loss:.4f}')
Important Note on Loss Scaling: The exact values of the reconstruction loss and KL divergence depend on whether you sum or average over pixels, latent dimensions, and batch items, so be consistent. The ELBO is an expectation, so averaging over the batch (and dataset) gives a Monte Carlo estimate. For the KL divergence, summing over latent dimensions and then averaging over the batch is standard; for the reconstruction term, summing over pixels and then averaging over the batch is also common. The loss_function above sums both over the batch, so loss.item() is a sum over the batch, and dividing by len(train_loader.dataset) normalizes it to a per-sample value.
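If you prefer to optimize a per-sample loss directly inside the loop, a common alternative convention (a sketch, not a change required by the code above) is to divide the summed loss by the batch size before backpropagation:
# Alternative convention: optimize the per-sample (batch-averaged) negative ELBO
loss = loss_function(x_reconstructed_logits, data, mu, log_var) / data.size(0)
The optimum is unchanged since the loss is only rescaled by a constant, but the gradient magnitudes differ by that factor, so pick one convention and keep it (and the learning rate) consistent.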
Once training is complete (or even during), it's essential to diagnose your VAE's performance.
Plotting the total loss (negative ELBO), reconstruction loss, and KL divergence over epochs provides insights into the training dynamics.
Example of VAE loss components during training (values are illustrative). The total loss and reconstruction loss should generally decrease. The KL divergence might increase initially as the encoder learns to utilize the latent space, then stabilize or slowly decrease.
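With the train_losses, recon_losses, and kld_losses lists populated by the training loop above, a simple plot (a minimal sketch) makes these dynamics visible:
# Plot the tracked loss components over epochs
epochs = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Total loss (negative ELBO)')
plt.plot(epochs, recon_losses, label='Reconstruction loss')
plt.plot(epochs, kld_losses, label='KL divergence')
plt.xlabel('Epoch')
plt.ylabel('Average loss per sample')
plt.legend()
plt.title('VAE training dynamics')
plt.show()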
Comparing original images to their reconstructions is a direct way to assess VAE performance.
model.eval()  # Set model to evaluation mode
with torch.no_grad():
    # Get a batch of test data
    data, _ = next(iter(test_loader))
    data = data.view(-1, image_size).to(device)

    x_reconstructed_logits, _, _ = model(data)
    # Apply sigmoid to get probabilities for visualization
    x_reconstructed = torch.sigmoid(x_reconstructed_logits)

# Display original and reconstructed images
n_images = 10
plt.figure(figsize=(20, 4))
for i in range(n_images):
    # Original
    ax = plt.subplot(2, n_images, i + 1)
    plt.imshow(data[i].cpu().numpy().reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if i == 0:
        ax.set_title("Original")

    # Reconstruction
    ax = plt.subplot(2, n_images, i + 1 + n_images)
    plt.imshow(x_reconstructed[i].cpu().numpy().reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if i == 0:
        ax.set_title("Reconstructed")

plt.show()
Look for sharpness, preservation of key features, and overall fidelity.
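Visual inspection is qualitative; for a number you can compare across runs, you can also estimate the average negative ELBO on the test set. This is a minimal sketch that reuses the loss_function defined earlier:
# Estimate the average negative ELBO per test image
model.eval()
test_loss = 0.0
with torch.no_grad():
    for data, _ in test_loader:
        data = data.view(-1, image_size).to(device)
        x_reconstructed_logits, mu, log_var = model(data)
        test_loss += loss_function(x_reconstructed_logits, data, mu, log_var).item()
avg_test_loss = test_loss / len(test_loader.dataset)
print(f'Average test negative ELBO: {avg_test_loss:.4f}')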
A hallmark of generative models is their ability to produce new data. We can do this by sampling z from the prior p(z) (e.g., N(0,I)) and passing it through the decoder.
model.eval()
with torch.no_grad():
    # Sample latent vectors from the prior N(0, I)
    num_samples = 10
    z_samples = torch.randn(num_samples, latent_dims).to(device)

    # Decode them to generate images
    generated_logits = model.decoder(z_samples)
    generated_images = torch.sigmoid(generated_logits)

plt.figure(figsize=(15, 3))
for i in range(num_samples):
    ax = plt.subplot(1, num_samples, i + 1)
    plt.imshow(generated_images[i].cpu().numpy().reshape(28, 28), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.suptitle("Generated Samples from Prior", fontsize=16)
plt.show()
Assess the quality and diversity of these generated samples. Do they look like MNIST digits? Are they varied?
Blurry Reconstructions/Generations: This is a common characteristic of VAEs, especially with simple decoders and Gaussian output assumptions. The model might be averaging over multiple plausible high-frequency details, leading to smoothness. Stronger decoders (e.g., using transposed convolutions for images, or autoregressive decoders discussed in Chapter 3) can help. The choice of reconstruction loss (e.g., L2 vs L1 vs BCE) can also impact sharpness.
Posterior Collapse (KL Vanishing): This occurs when the KL divergence term DKL(qϕ(z∣x)∣∣p(z)) trends towards zero during training. It means the approximate posterior qϕ(z∣x) becomes very close to the prior p(z), irrespective of the input x. Consequently, the latent variable z carries little to no information about x, and the decoder essentially learns to ignore z and generate an average output.
If the KL divergence term in your loss plot stays very small relative to the size of the latent space (latent_dims) or rapidly drops to near zero and stays there, you might have posterior collapse. The reconstructions might still look okay if the decoder is powerful enough to model the data distribution unconditionally, but the model won't be useful for representation learning or conditional generation.
"Holes" in the Latent Space: The prior p(z) encourages latent codes to be near the origin. However, not all points in the latent space sampled from p(z) might decode to realistic samples if the manifold of learned representations qϕ(z∣x) is sparse or has gaps. Interpolating between latent codes of known data points can help visualize the smoothness of the learned manifold; a sketch follows below.
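As a concrete check of latent-space smoothness, here is a minimal interpolation sketch. The choice of the first two test images and ten interpolation steps is arbitrary:
# Linear interpolation between the latent codes of two test images
model.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data.view(-1, image_size).to(device)
    mu, _ = model.encoder(data)        # use the posterior means as latent codes
    z_start, z_end = mu[0], mu[1]      # two arbitrary test images

    n_steps = 10
    plt.figure(figsize=(15, 2))
    for i in range(n_steps):
        alpha = i / (n_steps - 1)
        z = (1 - alpha) * z_start + alpha * z_end
        img = torch.sigmoid(model.decoder(z.unsqueeze(0)))
        ax = plt.subplot(1, n_steps, i + 1)
        plt.imshow(img.cpu().numpy().reshape(28, 28), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.suptitle("Latent Space Interpolation", fontsize=14)
    plt.show()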
If your latent_dims is 2, you can directly visualize how the decoder maps regions of the latent space to images. For higher dimensions, dimensionality reduction techniques like t-SNE or UMAP can project the latent codes z (obtained by encoding test data) into 2D, which can then be plotted and colored by their true labels (e.g., digit class for MNIST). A well-trained VAE often shows some clustering of similar data points in the latent space.
# Visualizing the latent space with t-SNE (requires scikit-learn)
from sklearn.manifold import TSNE

# Encode the test set, using the posterior means as latent codes
model.eval()
all_mu, all_labels = [], []
with torch.no_grad():
    for data, labels in test_loader:
        data = data.view(-1, image_size).to(device)
        mu, _ = model.encoder(data)
        all_mu.append(mu.cpu())
        all_labels.append(labels)
test_mu = torch.cat(all_mu).numpy()
test_labels = torch.cat(all_labels).numpy()

# Project to 2D (t-SNE can be slow on the full 10,000-image test set)
z_tsne = TSNE(n_components=2, random_state=0).fit_transform(test_mu)
plt.figure(figsize=(10, 8))
plt.scatter(z_tsne[:, 0], z_tsne[:, 1], c=test_labels, cmap='tab10', s=5)
plt.colorbar()
plt.title('t-SNE of Latent Space')
plt.xlabel('t-SNE dimension 1')
plt.ylabel('t-SNE dimension 2')
plt.show()
This hands-on exercise should give you a practical feel for how VAEs are built and what to look for during training. The diagnostics discussed here are fundamental. As you move to more advanced VAE architectures and applications in subsequent chapters, these basic checks will remain your first line of analysis. Remember that each component (the encoder, decoder, reparameterization, and loss terms) directly maps to the mathematical framework you've learned, and understanding their interplay is key to mastering VAEs.