In this hands-on section, we'll transition from the theoretical underpinnings of Variational Autoencoders (VAEs) to a practical implementation. You've learned how VAEs differ from standard autoencoders by learning a probabilistic latent space, enabling them to generate new data samples. We'll build a VAE using PyTorch, train it on the MNIST dataset of handwritten digits, and then inspect its latent space to see how it organizes the data. Finally, we'll sample from this latent space to generate new digit images.
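This exercise will solidify your understanding of:

- how the encoder, the reparameterization trick, and the decoder fit together in a VAE,
- how the VAE loss combines a reconstruction term with a KL divergence term,
- how to inspect a trained model's latent space and sample from it to generate new data.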
Let's get started!
First, make sure PyTorch and the other libraries used in this section are installed. If not, you can typically install them with pip:

```bash
pip install torch torchvision matplotlib numpy scipy
```
Now, let's import the necessary libraries for our VAE implementation.
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```
We're importing `numpy` for numerical operations, `matplotlib.pyplot` for plotting, `scipy.stats.norm` for generating a grid from a normal distribution (useful for visualizing the manifold of generated digits), and various components from `torch` and `torchvision`.
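The MNIST dataset is a classic machine learning benchmark. It consists of 70,000 grayscale images of handwritten digits (0-9), each 28×28 pixels, split into 60,000 training and 10,000 test images. It suits this exercise well for two reasons: it is simple enough for a VAE with a 2D latent space to model reasonably well, and a 2D latent space can be plotted directly.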
```python
# Convert images to tensors (ToTensor also scales pixels to [0, 1]) and flatten them
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten 1x28x28 -> 784
])

# Load the dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Define DataLoaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define image dimensions and latent dimension
image_size = 28
original_dim = image_size * image_size  # 28 * 28 = 784
latent_dim = 2
```
Here, we load the MNIST data using `torchvision.datasets`. We define a `transform` that converts each image to a PyTorch tensor with pixel values in [0, 1] and flattens it into a 784-dimensional vector. We then create `DataLoader` instances for efficient batching during training. We also set `latent_dim = 2`, meaning our VAE will compress each image into a 2-dimensional latent vector. This low dimensionality is chosen specifically so we can easily plot and inspect the structure of the latent space.
A VAE has three main parts: the encoder, the sampling layer (implementing the reparameterization trick), and the decoder. In PyTorch, we define the encoder and decoder as `nn.Module` classes. The sampling step has no learnable parameters, so here it is a plain function.
The encoder takes an input image and maps it to the parameters (mean and log-variance) of a Gaussian distribution in the latent space.
```python
# Encoder network
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        z_mean = self.fc_mean(h)
        z_log_var = self.fc_log_var(h)
        return z_mean, z_log_var

encoder = Encoder(original_dim, 256, latent_dim)
print(encoder)
```
Our encoder is a simple feedforward neural network. It takes the flattened 784-dimensional image and passes it through a dense layer with 256 units and ReLU activation. It then has two output layers: one for `z_mean` and one for `z_log_var`. Both output layers have `latent_dim` units, matching the dimensionality of our chosen latent space.
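As a quick worked check on the encoder's size, the sketch below rebuilds its three linear layers as standalone `nn.Linear` modules, with the same shapes as `Encoder(784, 256, 2)`, and counts their parameters. `n_params` is a hypothetical helper. The sketch also pushes a random batch through the layers to confirm that both heads return one `latent_dim`-sized vector per image.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Standalone copies of the encoder's layers: Encoder(784, 256, 2)
fc1 = nn.Linear(784, 256)
fc_mean = nn.Linear(256, 2)
fc_log_var = nn.Linear(256, 2)

def n_params(*layers):
    """Hypothetical helper: count weights plus biases across the given layers."""
    return sum(p.numel() for layer in layers for p in layer.parameters())

hidden_params = n_params(fc1)                # 784 * 256 + 256
head_params = n_params(fc_mean, fc_log_var)  # 2 * (256 * 2 + 2)
total_params = hidden_params + head_params
print(hidden_params, head_params, total_params)

# Shape check on a random batch of 4 flattened "images"
x = torch.rand(4, 784)
h = F.relu(fc1(x))
print(fc_mean(h).shape, fc_log_var(h).shape)
```

Running `sum(p.numel() for p in encoder.parameters())` on the tutorial's model should report the same total, since it has exactly these layers. Almost all of the parameters sit in the first layer.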
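To train the VAE with backpropagation, we need a way to sample from the approximate posterior $q(z \mid x) = \mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big)$ without breaking the gradient flow. The encoder outputs the parameters of this distribution as `z_mean` ($\mu$) and `z_log_var` ($\log \sigma^2$). The reparameterization trick moves the randomness into an auxiliary variable:

$$z = \mu + \exp\left(0.5 \cdot \log \sigma^2\right) \odot \epsilon = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I).$$

Because $\epsilon$ does not depend on the network's parameters, $z$ is a differentiable function of $\mu$ and $\log \sigma^2$. Gradients can therefore flow back into the encoder.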
```python
# Sampling function (reparameterization trick)
def sampling(z_mean, z_log_var):
    std = torch.exp(0.5 * z_log_var)
    epsilon = torch.randn_like(std)  # Sample from standard normal
    return z_mean + std * epsilon
```
The `sampling` function takes the `z_mean` and `z_log_var` tensors as input. It calculates the standard deviation from `z_log_var`, draws `epsilon` from a standard normal distribution with the same shape, and then applies the reparameterization formula.
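The sketch below shows why the reparameterization trick keeps gradients flowing. It uses a variant of `sampling` that also returns the noise it drew. The name `sampling_with_eps` is hypothetical and chosen so it does not overwrite the tutorial's `sampling`. After backpropagating through $z$, both encoder outputs receive gradients: $\partial z / \partial \mu = 1$ and $\partial z / \partial \log\sigma^2 = 0.5\,\sigma\,\epsilon$.

```python
import torch

torch.manual_seed(0)

def sampling_with_eps(z_mean, z_log_var):
    """Same computation as sampling(), but also returns epsilon for inspection."""
    std = torch.exp(0.5 * z_log_var)
    epsilon = torch.randn_like(std)  # the only source of randomness
    return z_mean + std * epsilon, epsilon

# Toy encoder outputs for a batch of 3 examples in a 2D latent space
z_mean = torch.tensor([[0.0, 1.0], [0.5, -0.5], [-1.0, 0.0]], requires_grad=True)
z_log_var = torch.tensor([[0.0, -1.0], [0.5, 0.0], [-2.0, 1.0]], requires_grad=True)

z, epsilon = sampling_with_eps(z_mean, z_log_var)
z.sum().backward()

# Analytic gradients of z.sum(): 1 for z_mean, 0.5 * std * epsilon for z_log_var
expected_log_var_grad = 0.5 * torch.exp(0.5 * z_log_var.detach()) * epsilon
print(z_mean.grad)
print(torch.allclose(z_log_var.grad, expected_log_var_grad))
```

Sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly would place the random draw itself between the loss and the encoder, and this gradient path would not exist.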
The decoder takes a point z sampled from the latent space and maps it back to the original data space, attempting to reconstruct the input image.
```python
# Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        reconstruction = torch.sigmoid(self.fc_out(h))  # Sigmoid for pixel values in [0, 1]
        return reconstruction

decoder = Decoder(latent_dim, 256, original_dim)
print(decoder)
```
The decoder mirrors the encoder's structure to some extent. It takes a 2D latent vector, passes it through a dense layer with 256 units (ReLU activation), and then outputs a 784-dimensional vector. We use a sigmoid activation function in the output layer because we want to reconstruct pixel values that are normalized between 0 and 1.
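The sketch below checks two properties of a decoder with this architecture. It uses an `nn.Sequential` stand-in called `decoder_sketch`, which is an assumption: it has the same layer shapes as `Decoder(2, 256, 784)` but is untrained. The sketch shows that the sigmoid keeps every output pixel in [0, 1], even for latent points far outside the prior, and it counts the decoder's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Untrained stand-in with the same layer shapes as Decoder(2, 256, 784)
decoder_sketch = nn.Sequential(
    nn.Linear(2, 256),
    nn.ReLU(),
    nn.Linear(256, 784),
    nn.Sigmoid(),  # squashes every output pixel into [0, 1]
)

# Latent points near the prior's center and far outside its typical range
z = torch.tensor([[0.0, 0.0], [1.5, -1.5], [50.0, -50.0]])
with torch.no_grad():
    out = decoder_sketch(z)

images = out.view(-1, 28, 28)  # one 28x28 image per latent point
n_decoder_params = sum(p.numel() for p in decoder_sketch.parameters())  # 768 + 201,488
print(images.shape, float(out.min()), float(out.max()), n_decoder_params)
```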
Now, let's connect these components to form the complete VAE model.
```python
# VAE Model
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = sampling(z_mean, z_log_var)
        reconstruction = self.decoder(z)
        return reconstruction, z_mean, z_log_var

vae = VAE(encoder, decoder)
print(vae)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.to(device)
```
The `VAE` class combines the `encoder` and `decoder`. Its `forward` method takes the input `x` and passes it through the encoder to get `z_mean` and `z_log_var`. It then samples `z` using the `sampling` function and passes `z` through the decoder to get the `reconstruction`. We also move the model to the `cuda` device if a GPU is available.
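One subtlety of this design is that `sampling` is a plain function. Calling `vae.eval()` therefore does not turn it off, and every forward pass decodes a freshly sampled $z$. The sketch below demonstrates this with small untrained stand-in modules. `ToyEncoder` and `ToyVAE` are hypothetical and only mimic the interface of `Encoder` and `VAE`. The sketch also shows a hypothetical helper, `reconstruct_deterministic`, that decodes `z_mean` instead, giving repeatable reconstructions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def sampling(z_mean, z_log_var):
    # Same reparameterization as the tutorial's sampling()
    std = torch.exp(0.5 * z_log_var)
    return z_mean + std * torch.randn_like(std)

class ToyEncoder(nn.Module):
    """Hypothetical stand-in with the Encoder interface: returns (z_mean, z_log_var)."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(784, 4)

    def forward(self, x):
        z_mean, z_log_var = self.fc(x).chunk(2, dim=1)
        return z_mean, z_log_var

class ToyVAE(nn.Module):
    """Hypothetical stand-in with the VAE interface: forward() samples z."""
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()
        self.decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())

    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        return self.decoder(sampling(z_mean, z_log_var)), z_mean, z_log_var

def reconstruct_deterministic(model, x):
    """Hypothetical helper: decode the posterior mean instead of a random sample."""
    z_mean, _ = model.encoder(x)
    return model.decoder(z_mean)

toy_vae = ToyVAE().eval()  # eval() has no effect on sampling()
x = torch.rand(2, 784)
with torch.no_grad():
    r1, _, _ = toy_vae(x)
    r2, _, _ = toy_vae(x)
    d1 = reconstruct_deterministic(toy_vae, x)
    d2 = reconstruct_deterministic(toy_vae, x)

print(torch.equal(r1, r2))  # forward() draws fresh noise on every call
print(torch.equal(d1, d2))  # decoding z_mean involves no noise
```

Because `reconstruct_deterministic` only uses the `encoder` and `decoder` attributes, it should also work unchanged with the tutorial's `vae`.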
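The VAE loss function has two parts:

- Reconstruction loss, which measures how well the decoder reproduces the input. Because pixel values lie in [0, 1], we use binary cross-entropy between the reconstruction and the original image.
- KL divergence, which measures how far the encoder's distribution $q(z \mid x)$ is from the standard normal prior $p(z) = \mathcal{N}(0, I)$. It acts as a regularizer on the latent space.

For a diagonal Gaussian $q(z \mid x) = \mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big)$, the KL term has a closed form, where $d$ is `latent_dim`:

$$D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

We'll define a function to calculate this combined loss.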
```python
# VAE loss function
def vae_loss_function(reconstruction, x, z_mean, z_log_var):
    # Reconstruction loss (binary cross-entropy).
    # F.binary_cross_entropy averages over every element by default (reduction='mean');
    # reduction='sum' instead sums over all pixels of all images in the batch.
    reconstruction_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # summed over latent dimensions and over the batch
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return reconstruction_loss + kl_loss
```
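Here, `reconstruction_loss` is calculated using `F.binary_cross_entropy` between the `reconstruction` and the original `x`. We use `reduction='sum'` to sum over all pixels and all images in the batch. The default, `reduction='mean'`, would average over every pixel of every image. That would make the reconstruction term tiny compared with the summed KL term and upset the balance between the two. The `kl_loss` is the closed-form KL divergence between the encoder's diagonal Gaussian and the standard normal prior, likewise summed over latent dimensions and the batch.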
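The closed-form KL expression can be sanity-checked numerically. The sketch below uses assumed toy values for $\mu$ and $\log\sigma^2$ of a single image. It compares the formula from `vae_loss_function` with a Monte Carlo estimate of $\mathbb{E}_{q}\left[\log q(z \mid x) - \log p(z)\right]$, drawing $z$ through the reparameterization $z = \mu + \sigma \odot \epsilon$.

```python
import numpy as np

# Assumed toy encoder outputs for one image in a 2D latent space
mu = np.array([0.5, -1.0])
log_var = np.array([0.0, -0.5])
std = np.exp(0.5 * log_var)

# Closed form, as in vae_loss_function: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
kl_closed = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z|x) - log p(z)], with z = mu + std * eps
rng = np.random.default_rng(0)
eps = rng.standard_normal((200_000, 2))
z = mu + std * eps
log_q = -0.5 * (np.log(2 * np.pi) + log_var + eps**2)  # ((z - mu) / std)**2 == eps**2
log_p = -0.5 * (np.log(2 * np.pi) + z**2)              # standard normal prior
kl_mc = np.mean(np.sum(log_q - log_p, axis=1))

print(f"closed form: {kl_closed:.4f}, Monte Carlo: {kl_mc:.4f}")
```

The two numbers should agree to about two decimal places. The closed form is exact and cheaper to compute, which is why the loss function uses it.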
With the model and loss function defined, we can set up the optimizer and train the VAE.
```python
# Optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Training loop
epochs = 30  # You might need more epochs for better results
train_losses = []
val_losses = []

for epoch in range(epochs):
    # Training
    vae.train()
    total_train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        reconstruction, z_mean, z_log_var = vae(data)
        loss = vae_loss_function(reconstruction, data, z_mean, z_log_var)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_dataset)
    train_losses.append(avg_train_loss)

    # Validation
    vae.eval()
    total_val_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            reconstruction, z_mean, z_log_var = vae(data)
            loss = vae_loss_function(reconstruction, data, z_mean, z_log_var)
            total_val_loss += loss.item()
    avg_val_loss = total_val_loss / len(test_dataset)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

# Plot training & validation loss values
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.show()
```
We use the Adam optimizer for training. The training loop iterates over epochs and batches. For each batch, we perform a forward pass, calculate the VAE loss, backpropagate, and update the model parameters. Each batch loss is a sum over its images, so dividing the epoch total by the dataset size gives an average loss per image. We also include a validation step to monitor performance on unseen data. For simplicity, the MNIST test split serves as the validation set here. Monitoring both curves shows whether the model is still learning. For a production model, you'd likely train for more epochs.
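The training loop records only the combined loss, so it can be hard to tell which term is changing. A minimal sketch for logging the two terms separately is a hypothetical variant of the loss function, `vae_loss_terms`, that returns the total together with each part. The random tensors below only exercise it.

```python
import torch
import torch.nn.functional as F

def vae_loss_terms(reconstruction, x, z_mean, z_log_var):
    """Hypothetical variant of vae_loss_function that also returns each term."""
    reconstruction_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return reconstruction_loss + kl_loss, reconstruction_loss, kl_loss

torch.manual_seed(0)
x = torch.rand(8, 784)                                  # targets in [0, 1]
reconstruction = torch.rand(8, 784).clamp(0.01, 0.99)   # stand-in decoder output
z_mean = torch.zeros(8, 2)
z_log_var = torch.zeros(8, 2)                           # posterior equals the prior here

total, rec, kl = vae_loss_terms(reconstruction, x, z_mean, z_log_var)
print(float(total), float(rec), float(kl))  # kl is exactly 0 when mu=0 and log_var=0
```

In the training loop you could call this in place of `vae_loss_function`, backpropagate `total`, and accumulate `rec.item()` and `kl.item()` separately. A KL term that drops to nearly zero while reconstruction stays poor suggests that the latent code carries little information about the input.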
One of the most insightful aspects of VAEs is examining the structure of their latent space. Since we chose `latent_dim = 2`, we can create a 2D scatter plot of the MNIST digits in this space.
```python
def plot_latent_space(encoder_model, data_loader, n_samples=10000):
    encoder_model.eval()  # Set encoder to evaluation mode
    z_means = []
    labels = []
    n_collected = 0  # Running count of encoded images (batches may differ in size)
    with torch.no_grad():
        for data, label in data_loader:
            if n_collected >= n_samples:
                break  # Limit samples for visualization
            data = data.to(device)
            z_mean, _ = encoder_model(data)
            z_means.append(z_mean.cpu().numpy())
            labels.append(label.cpu().numpy())
            n_collected += data.shape[0]
    z_means = np.concatenate(z_means, axis=0)[:n_samples]
    labels = np.concatenate(labels, axis=0)[:n_samples]

    plt.figure(figsize=(12, 10))
    # Digit labels are categories, so a qualitative colormap like 'tab10' keeps them distinct
    plt.scatter(z_means[:, 0], z_means[:, 1], c=labels, cmap='tab10', s=5)
    plt.colorbar(label='Digit Label')
    plt.xlabel("Latent Dimension 1 ($z_1$)")
    plt.ylabel("Latent Dimension 2 ($z_2$)")
    plt.title("MNIST Test Data in 2D Latent Space (Mean Values)")
    plt.grid(True)
    plt.show()

# Use the trained encoder model part to get z_mean
plot_latent_space(vae.encoder, test_loader)
```
The `plot_latent_space` function uses the `encoder` part of our trained VAE to get the `z_mean` vectors for the test images. It then creates a scatter plot in which each point is an image, positioned by its 2D latent representation and colored by its true digit label (`labels`).
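You should observe that the VAE has organized the digits in a structured way. Images of the same digit tend to form clusters, and digits that look similar (e.g., 1s and 7s, or 3s and 8s) may lie close together. Roughly speaking, the reconstruction term encourages this grouping, since similar images benefit from similar codes. The KL divergence term pulls all codes toward the standard normal prior, which keeps the occupied region compact and without large gaps so that the space stays continuous.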
[Figure: A sample of MNIST test images projected into the 2D latent space. Each point represents an image, colored by its digit label, showing how the VAE has organized the different digits.]
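To back up the visual impression with numbers, you could compute the mean latent position of each digit class and find which classes sit closest together. The helpers `class_centroids` and `closest_classes` below are hypothetical. The synthetic `z_means` and `labels` arrays stand in for the ones built inside `plot_latent_space`, which would need to return them.

```python
import numpy as np

def class_centroids(z_means, labels):
    """Mean latent position of each class label."""
    return {int(c): z_means[labels == c].mean(axis=0) for c in np.unique(labels)}

def closest_classes(centroids):
    """Return (label_a, label_b, distance) for the nearest pair of centroids."""
    keys = sorted(centroids)
    best = None
    for i, a in enumerate(keys):
        for b in keys[i + 1:]:
            dist = float(np.linalg.norm(centroids[a] - centroids[b]))
            if best is None or dist < best[2]:
                best = (a, b, dist)
    return best

# Synthetic stand-in: three "digit" classes clustered around known centers
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 100)
centers = np.array([[-2.0, 0.0], [-1.0, 0.0], [2.0, 0.0]])
z_means = centers[labels] + 0.1 * rng.standard_normal((300, 2))

centroids = class_centroids(z_means, labels)
print({k: np.round(v, 2) for k, v in centroids.items()})
print(closest_classes(centroids))
```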
The real magic of VAEs as generative models comes alive when we sample points from the latent space and use the decoder to transform them into new images. Since the latent space is encouraged to be continuous, nearby points in the latent space should decode to visually similar images.
```python
def plot_generated_images_manifold(decoder_model, n=15, figure_size=15, latent_dim_val=2):
    decoder_model.eval()  # Set decoder to evaluation mode
    # Display a 2D manifold of digits by decoding a grid of latent points
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # Coordinates at evenly spaced percentiles (5% to 95%) of the standard normal,
    # i.e. the region where the Gaussian prior puts most of its mass
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
    # Reversed so that z_2 increases from the bottom row to the top row
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))[::-1]

    with torch.no_grad():
        for i, yi in enumerate(grid_y):      # rows: latent dimension 2
            for j, xi in enumerate(grid_x):  # columns: latent dimension 1
                if latent_dim_val == 2:
                    z_sample = torch.tensor([[xi, yi]], dtype=torch.float32, device=device)
                else:
                    # For a higher latent_dim, vary the first two dimensions and
                    # hold the others at 0, the mean of the prior
                    z_sample = torch.zeros(1, latent_dim_val, device=device)
                    z_sample[0, 0] = xi
                    z_sample[0, 1] = yi
                x_decoded = decoder_model(z_sample)
                digit = x_decoded[0].cpu().numpy().reshape(digit_size, digit_size)
                figure[i * digit_size: (i + 1) * digit_size,
                       j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(figure_size, figure_size))
    plt.imshow(figure, cmap='Greys_r')
    ax = plt.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("Latent Dimension 1 variation")
    ax.set_ylabel("Latent Dimension 2 variation")
    plt.title("Manifold of Generated Digits from Latent Space")
    plt.show()

# Use the trained decoder model part
plot_generated_images_manifold(vae.decoder, n=20, latent_dim_val=latent_dim)
```
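The `plot_generated_images_manifold` function creates a grid of points in the 2D latent space. The coordinates are taken at evenly spaced percentiles from 5% to 95% of the standard normal distribution, so the grid covers the central 90% of the prior along each axis, where most training codes are expected to lie. For each point (`z_sample`), the function uses the `decoder` to generate an image. These images are then arranged into one large grid and displayed.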
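To make the grid construction concrete, the sketch below evaluates the same `norm.ppf(np.linspace(0.05, 0.95, n))` expression for a small `n = 5`. Equal steps in probability translate into unequal steps in $z$. The points are packed near 0, where the prior has most of its density, and spread out toward roughly ±1.645, the 5th and 95th percentiles of the standard normal.

```python
import numpy as np
from scipy.stats import norm

n = 5
percentiles = np.linspace(0.05, 0.95, n)  # evenly spaced in probability
grid = norm.ppf(percentiles)              # inverse CDF of the standard normal
steps = np.diff(grid)                     # spacing between neighbouring grid points

print(np.round(percentiles, 3))
print(np.round(grid, 3))
print(np.round(steps, 3))
```

With `n=20`, as in the call to `plot_generated_images_manifold`, the same expression gives 20 points over the same range.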
You should see a smooth transition between different types of digits as you move across the latent space. For example, a digit might gradually morph from a '1' to a '7' or from a '4' to a '9'. This demonstrates the VAE's ability to learn a meaningful and continuous representation.
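Besides the regular grid, the most direct way to use the VAE as a generative model is to draw latent vectors from the prior $\mathcal{N}(0, I)$ and decode them. The sketch below wraps this in a hypothetical helper, `generate_random_digits`. To stay self-contained, it runs the helper on an untrained stand-in decoder, so the outputs are noise rather than digits.

```python
import torch
import torch.nn as nn

def generate_random_digits(decoder_model, n_samples, latent_dim, device="cpu"):
    """Hypothetical helper: decode n_samples latent vectors drawn from N(0, I)."""
    decoder_model.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim, device=device)
        images = decoder_model(z)
    return images.view(n_samples, 28, 28).cpu()

# Untrained stand-in with the decoder's layer shapes (assumption, for illustration)
torch.manual_seed(0)
stand_in = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 784), nn.Sigmoid())
samples = generate_random_digits(stand_in, n_samples=16, latent_dim=2)
print(samples.shape)
```

With the trained model you would call `generate_random_digits(vae.decoder, 16, latent_dim, device)` and display each resulting 28×28 array with `plt.imshow`.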
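In this hands-on section, you've built, trained, and inspected a Variational Autoencoder using PyTorch. You saw how to:

- implement the encoder, the reparameterization-based sampling step, and the decoder,
- combine a reconstruction loss and a KL divergence into the VAE objective,
- train the model on MNIST while monitoring training and validation loss,
- visualize how the encoder organizes the test digits in a 2D latent space,
- decode points from that space to generate new digit images.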
The features learned by the VAE's encoder, specifically the `z_mean` vectors, can be valuable for downstream tasks. As discussed earlier in this chapter ("Using VAE Latent Representations as Features"), these compressed, structured representations can often improve the performance of classifiers or other machine learning models, especially when dealing with high-dimensional data.
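Here is a minimal sketch of the feature-extraction idea, assuming scikit-learn is installed. `extract_z_mean` (hypothetical) collects `z_mean` vectors and labels from a data loader in the same way as `plot_latent_space`. `latent_feature_accuracy` (also hypothetical) fits `LogisticRegression` on those features and reports test accuracy. The run below uses synthetic, well-separated 2D features in place of real encoder outputs, so its accuracy says nothing about MNIST.

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

def extract_z_mean(encoder_model, data_loader, device):
    """Collect z_mean features and labels for every batch in data_loader."""
    encoder_model.eval()
    features, targets = [], []
    with torch.no_grad():
        for data, label in data_loader:
            z_mean, _ = encoder_model(data.to(device))
            features.append(z_mean.cpu().numpy())
            targets.append(label.numpy())
    return np.concatenate(features), np.concatenate(targets)

def latent_feature_accuracy(z_train, y_train, z_test, y_test):
    """Fit a linear classifier on latent features and return its test accuracy."""
    clf = LogisticRegression(max_iter=1000)
    clf.fit(z_train, y_train)
    return clf.score(z_test, y_test)

# Synthetic stand-in features: three well-separated clusters in 2D
rng = np.random.default_rng(0)
cluster_centers = np.array([[-2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
y_train = rng.integers(0, 3, size=300)
y_test = rng.integers(0, 3, size=100)
z_train = cluster_centers[y_train] + 0.3 * rng.standard_normal((300, 2))
z_test = cluster_centers[y_test] + 0.3 * rng.standard_normal((100, 2))

acc = latent_feature_accuracy(z_train, y_train, z_test, y_test)
print(f"accuracy: {acc:.2f}")
```

With the trained model, `extract_z_mean(vae.encoder, train_loader, device)` and the analogous call on `test_loader` would supply real features. Training the same classifier on raw pixels gives a baseline for comparison.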
Experiment further by:

- Varying `latent_dim`. What happens if it's larger? Or just 1?
- Training a classifier on the learned latent features (e.g., using `sklearn`'s `LogisticRegression` on the MNIST latent features) and comparing its performance to a classifier trained on raw pixels.

This practical experience forms a solid foundation for applying VAEs to more complex problems and understanding their role in both feature extraction and generative modeling.
© 2025 ApX Machine Learning