In this section, we work through a practical implementation of a Variational Autoencoder (VAE). VAEs learn a probabilistic latent space, which sets them apart from standard autoencoders and enables the generation of new data samples. We will build a VAE in PyTorch, train it on the MNIST dataset of handwritten digits, and then inspect its latent space to see how it organizes the data. Finally, we will sample from this latent space to generate new digit images.

This exercise will solidify your understanding of:

- Constructing the encoder to output a mean ($z_{\text{mean}}$) and log-variance ($z_{\text{log\_var}}$).
- Implementing the reparameterization trick for sampling.
- Defining the VAE-specific loss function (reconstruction + KL divergence).
- Visualizing the learned latent space.
- Generating new data by sampling from the latent space.

Let's get started!

## 1. Setting Up Your Environment and Importing Libraries

First, ensure you have PyTorch installed. If not, you can typically install it via pip:

```bash
pip install torch torchvision matplotlib numpy
```

Now, let's import the necessary libraries for our VAE implementation.

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```

We're importing numpy for numerical operations, matplotlib.pyplot for plotting, scipy.stats.norm for generating a grid of points from a normal distribution (useful for visualizing the manifold of generated digits), and various components from torch and torchvision.

## 2. Loading and Preprocessing the MNIST Dataset

The MNIST dataset is a classic in machine learning, consisting of 70,000 grayscale images of handwritten digits (0-9), each 28x28 pixels. It's ideal for this exercise because the learned 2D latent space can be easily visualized.

```python
# Define a transform that scales pixels to [0, 1] and flattens each image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten each 28x28 image to a 784-dim vector
])

# Load the dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Define DataLoaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define image dimensions and latent dimension
image_size = 28
original_dim = image_size * image_size  # 28 * 28 = 784
latent_dim = 2
```

Here, we load the MNIST data using torchvision.datasets. The transform converts each image to a PyTorch tensor and flattens it into a 784-dimensional vector. We then create DataLoader instances for efficient batching during training. We also set latent_dim = 2, meaning our VAE will compress each image into a 2-dimensional latent vector. This low dimensionality is chosen specifically so we can easily plot and inspect the structure of the latent space.
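Before building the model, it can help to confirm that the data pipeline produces what we expect. This quick sanity check is an optional addition (not part of the original walkthrough); it assumes the `train_loader` defined in the previous block.

```python
# Optional sanity check: each batch should be flattened to 784 values in [0, 1].
images, labels = next(iter(train_loader))
print(images.shape)                              # expected: torch.Size([128, 784])
print(images.min().item(), images.max().item())  # expected: values within [0.0, 1.0]
print(labels[:10])                               # a few ground-truth digit labels
```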
## 3. Building the VAE Components

A VAE has three main parts: the encoder, the sampling layer (implementing the reparameterization trick), and the decoder. In PyTorch, we typically define these as nn.Module classes.

### The Encoder Network

The encoder takes an input image and maps it to the parameters (mean and log-variance) of a Gaussian distribution in the latent space.

```python
# Encoder network
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        z_mean = self.fc_mean(h)
        z_log_var = self.fc_log_var(h)
        return z_mean, z_log_var

encoder = Encoder(original_dim, 256, latent_dim)
print(encoder)
```

Our encoder is a simple feedforward neural network. It takes the flattened 784-dimensional image, passes it through a dense layer with 256 units and ReLU activation, and then branches into two output layers: one for z_mean and one for z_log_var. Both output layers have latent_dim units, matching the dimensionality of our chosen latent space.

### The Sampling Layer (Reparameterization Trick)

To train the VAE with backpropagation, we need a way to sample from the distribution $q(z|x)$ (defined by $z_{\text{mean}}$ and $z_{\text{log\_var}}$) without breaking the gradient flow. The reparameterization trick achieves this:

$$z = z_{\text{mean}} + \exp\!\left(0.5 \cdot z_{\text{log\_var}}\right) \cdot \epsilon,$$

where $\epsilon$ is sampled from a standard normal distribution $\mathcal{N}(0, I)$.

```python
# Sampling function (reparameterization trick)
def sampling(z_mean, z_log_var):
    std = torch.exp(0.5 * z_log_var)   # Convert log-variance to standard deviation
    epsilon = torch.randn_like(std)    # Sample from a standard normal distribution
    return z_mean + std * epsilon
```

The sampling function takes the z_mean and z_log_var tensors as input. It computes the standard deviation from z_log_var, draws epsilon from a standard normal distribution of the same shape, and applies the reparameterization formula.

### The Decoder Network

The decoder takes a point $z$ sampled from the latent space and maps it back to the original data space, attempting to reconstruct the input image.

```python
# Decoder network
class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.fc1(z))
        reconstruction = torch.sigmoid(self.fc_out(h))  # Sigmoid keeps pixel values in [0, 1]
        return reconstruction

decoder = Decoder(latent_dim, 256, original_dim)
print(decoder)
```

The decoder roughly mirrors the encoder's structure. It takes a 2D latent vector, passes it through a dense layer with 256 units (ReLU activation), and then outputs a 784-dimensional vector. We use a sigmoid activation in the output layer because we want to reconstruct pixel values normalized between 0 and 1.

### The Full VAE Model

Now, let's connect these components to form the complete VAE model.

```python
# VAE Model
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        z_mean, z_log_var = self.encoder(x)
        z = sampling(z_mean, z_log_var)
        reconstruction = self.decoder(z)
        return reconstruction, z_mean, z_log_var

vae = VAE(encoder, decoder)
print(vae)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.to(device)
```

The VAE class combines the encoder and decoder. Its forward method takes the input x, passes it through the encoder to get z_mean and z_log_var, samples z using the sampling function, and finally passes z through the decoder to produce the reconstruction. We also move the model to cuda if a GPU is available.
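At this point, it can be reassuring to run a single batch through the untrained model and check the output shapes. This smoke test is an optional addition (not part of the original walkthrough) and assumes the `vae`, `device`, and `train_loader` defined above.

```python
# Optional smoke test: verify output shapes before training.
vae.eval()
with torch.no_grad():
    x, _ = next(iter(train_loader))
    x = x.to(device)
    recon, z_mean, z_log_var = vae(x)

print(recon.shape)      # expected: torch.Size([128, 784])
print(z_mean.shape)     # expected: torch.Size([128, 2])
print(z_log_var.shape)  # expected: torch.Size([128, 2])
```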
## 4. Defining the VAE Loss Function

The VAE loss function has two parts:

- **Reconstruction loss:** measures how well the decoder reconstructs the input image. For MNIST, since pixel values are normalized to [0, 1] and can be treated as probabilities, binary cross-entropy (BCE) is a common choice.
- **KL divergence loss:** acts as a regularizer, pushing the distribution $q(z|x)$ learned by the encoder towards a prior distribution $p(z)$, typically a standard normal $\mathcal{N}(0, I)$. For a diagonal Gaussian posterior it has the closed form

$$D_{KL}\big(q(z|x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{j=1}^{d}\Big(1 + z_{\text{log\_var},\,j} - z_{\text{mean},\,j}^{2} - \exp\!\big(z_{\text{log\_var},\,j}\big)\Big),$$

where $d$ is the latent dimensionality (latent_dim).

We'll define a function to calculate this combined loss.

```python
# VAE loss function
def vae_loss_function(reconstruction, x, z_mean, z_log_var):
    # Reconstruction loss (binary cross-entropy), summed over pixels and over the batch.
    # reduction='sum' keeps the reconstruction term on the same scale as the KL term;
    # the default reduction='mean' would average over every element instead.
    reconstruction_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')

    # KL divergence loss, summed over latent dimensions and over the batch
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())

    return reconstruction_loss + kl_loss
```

Here, reconstruction_loss is calculated with F.binary_cross_entropy between the reconstruction and the original x, using reduction='sum' to sum over all elements. The kl_loss implements the closed-form expression above.

## 5. Training the VAE

With the model and loss function defined, we can set up the optimizer and train the VAE.

```python
# Optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Training loop
epochs = 30  # You might need more epochs for better results
train_losses = []
val_losses = []

for epoch in range(epochs):
    # Training
    vae.train()
    total_train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()
        reconstruction, z_mean, z_log_var = vae(data)
        loss = vae_loss_function(reconstruction, data, z_mean, z_log_var)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_dataset)
    train_losses.append(avg_train_loss)

    # Validation
    vae.eval()
    total_val_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            reconstruction, z_mean, z_log_var = vae(data)
            loss = vae_loss_function(reconstruction, data, z_mean, z_log_var)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(test_dataset)
    val_losses.append(avg_val_loss)

    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

# Plot training & validation loss values
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.show()
```

We use the Adam optimizer for training. The training loop iterates over epochs and batches: for each batch, we perform a forward pass, calculate the VAE loss, backpropagate, and update the model parameters. We also include a validation pass to monitor performance on unseen data. It's good practice to watch both losses to see whether the model is learning; for a production model, you'd likely train for more epochs.
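If you'd like to see how the two loss terms behave individually during training, a small variant of the loss function that returns them separately can help. This is an optional addition (not part of the walkthrough above); the function name is illustrative.

```python
# Optional: return the reconstruction and KL terms separately so each
# contribution can be logged per epoch and inspected on its own.
def vae_loss_components(reconstruction, x, z_mean, z_log_var):
    recon_loss = F.binary_cross_entropy(reconstruction, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return recon_loss, kl_loss

# Inside the training loop you could then write:
#   recon_loss, kl_loss = vae_loss_components(reconstruction, data, z_mean, z_log_var)
#   loss = recon_loss + kl_loss
```

Watching the two terms separately makes it easier to spot whether the KL term is dominating (over-regularizing the latent space) or contributing very little.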
## 6. Inspecting the Latent Space

One of the most insightful aspects of VAEs is examining the structure of their latent space. Since we chose latent_dim = 2, we can create a 2D scatter plot of the MNIST digits in this space.

```python
def plot_latent_space(encoder_model, data_loader, n_samples=10000):
    encoder_model.eval()  # Set encoder to evaluation mode
    z_means = []
    labels = []

    with torch.no_grad():
        for i, (data, label) in enumerate(data_loader):
            if len(z_means) * data.shape[0] >= n_samples:
                break  # Limit the number of samples for visualization
            data = data.to(device)
            z_mean, _ = encoder_model(data)
            z_means.append(z_mean.cpu().numpy())
            labels.append(label.cpu().numpy())

    z_means = np.concatenate(z_means, axis=0)[:n_samples]
    labels = np.concatenate(labels, axis=0)[:n_samples]

    plt.figure(figsize=(12, 10))
    plt.scatter(z_means[:, 0], z_means[:, 1], c=labels, cmap='viridis')  # Use 'viridis' or another distinct colormap
    plt.colorbar(label='Digit Label')
    plt.xlabel("Latent Dimension 1 ($z_1$)")
    plt.ylabel("Latent Dimension 2 ($z_2$)")
    plt.title("MNIST Test Data in 2D Latent Space (Mean Values)")
    plt.grid(True)
    plt.show()

# Use the trained encoder to get z_mean for the test set
plot_latent_space(vae.encoder, test_loader)
```

The plot_latent_space function uses the encoder part of our trained VAE to compute the $z_{\text{mean}}$ vectors for the test images. It then creates a scatter plot in which each point is an image, its position determined by its 2D latent representation and its color by its actual digit label.

You should observe that the VAE has learned to organize the digits in a somewhat structured way. Digits that look similar (e.g., 1s and 7s, or 3s and 8s) tend to sit closer together, and there may be clear clusters for different digits. The KL divergence term in the loss encourages this continuous, organized structure.

Here's an example of what a Plotly chart for such a visualization might look like (with fewer data points for brevity in this example JSON). In practice, you'd use the full z_means and labels arrays from your plot_latent_space function.

```json
{"layout": {"title": "Sample MNIST Latent Space (z_mean)",
            "xaxis": {"title": "Latent Dim 1"},
            "yaxis": {"title": "Latent Dim 2"},
            "height": 500, "width": 600},
 "data": [{"type": "scatter", "mode": "markers",
           "x": [-1.5, -1.2, 1.8, 2.1, 0.1, -0.2, 0.5, 0.7, -2.0, -2.3],
           "y": [2.0, 2.3, -1.0, -1.2, -0.5, -0.8, 2.5, 2.2, 0.3, 0.1],
           "marker": {"color": [0, 0, 1, 1, 2, 2, 7, 7, 8, 8],
                      "colorscale": "Viridis", "showscale": true,
                      "colorbar": {"title": "Digit"}}}]}
```

*A scatter plot showing a sample of MNIST test images projected into the 2D latent space. Each point represents an image, colored by its digit label. This helps visualize how the VAE has organized the different digits.*

## 7. Generating New Samples (Sampling from the Latent Space)

The real magic of VAEs as generative models comes alive when we sample points from the latent space and use the decoder to transform them into new images. Since the latent space is encouraged to be continuous, nearby points should decode to visually similar images.
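Before building the full manifold grid below, it can be instructive to simply decode a handful of random latent vectors drawn from the standard normal prior. This minimal sketch is an optional addition (not part of the original walkthrough) and assumes the trained `vae`, `latent_dim`, and `device` from earlier.

```python
# Optional: decode a few random latent vectors sampled from N(0, I).
vae.decoder.eval()
with torch.no_grad():
    z = torch.randn(8, latent_dim).to(device)                  # 8 random latent points
    samples = vae.decoder(z).cpu().numpy().reshape(-1, 28, 28)

fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for ax, img in zip(axes, samples):
    ax.imshow(img, cmap='Greys_r')
    ax.axis('off')
plt.show()
```

The function below goes further: it decodes an evenly spaced grid of latent points so we can see the full manifold of generated digits.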
```python
def plot_generated_images_manifold(decoder_model, n=15, figure_size=15, latent_dim_val=2):
    decoder_model.eval()  # Set decoder to evaluation mode

    # Display a 2D manifold of digits by sampling points from a grid in the latent space
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))

    # Linearly spaced coordinates mapped through the percentiles of the standard normal distribution
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                if latent_dim_val == 2:
                    z_sample = torch.tensor([[xi, yi]], dtype=torch.float32).to(device)
                else:
                    # For a higher latent_dim, vary only the first two dimensions for visualization
                    # and fill the remaining dimensions with random samples
                    z_sample_base = torch.randn(1, latent_dim_val).to(device)
                    z_sample_base[0, 0] = xi
                    z_sample_base[0, 1] = yi
                    z_sample = z_sample_base

                x_decoded = decoder_model(z_sample)
                digit = x_decoded[0].cpu().numpy().reshape(digit_size, digit_size)
                figure[i * digit_size: (i + 1) * digit_size,
                       j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(figure_size, figure_size))
    plt.imshow(figure, cmap='Greys_r')
    ax = plt.gca()
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel("Latent Dimension 1 variation")
    ax.set_ylabel("Latent Dimension 2 variation")
    plt.title("Manifold of Generated Digits from Latent Space")
    plt.show()

# Use the trained decoder model part
plot_generated_images_manifold(vae.decoder, n=20, latent_dim_val=latent_dim)
```

The plot_generated_images_manifold function creates a grid of points in the 2D latent space (sampling from regions that are likely under a Gaussian prior). For each point $z_{\text{sample}}$, it uses the decoder to generate an image. These images are then arranged into a large grid and displayed.

You should see smooth transitions between different types of digits as you move across the latent space. For example, a digit might gradually morph from a '1' into a '7', or from a '4' into a '9'. This demonstrates the VAE's ability to learn a meaningful, continuous representation.

## 8. Summary and What's Next

In this hands-on section, you've built, trained, and inspected a Variational Autoencoder in PyTorch. You saw how to:

- Define an encoder that outputs distributional parameters ($z_{\text{mean}}$, $z_{\text{log\_var}}$).
- Implement the reparameterization trick for sampling.
- Create a custom VAE loss function combining reconstruction and KL divergence terms.
- Train the VAE and visualize its learned latent space, observing how it clusters and organizes data like MNIST digits.
- Use the decoder to generate new data samples by traversing the latent space, showcasing the VAE's generative capabilities.

The features learned by the VAE's encoder, specifically the $z_{\text{mean}}$ vectors, can be valuable for downstream tasks. As discussed earlier in this chapter ("Using VAE Latent Representations as Features"), these compressed, structured representations can often improve the performance of classifiers or other machine learning models, especially when dealing with high-dimensional data.

Experiment further by:

- Trying different network architectures for the encoder and decoder.
- Adjusting latent_dim. What happens if it's larger? Or just 1?
- Training for more epochs or on different datasets.
- Applying the extracted latent features ($z_{\text{mean}}$) to a simple classification task (e.g., using sklearn's LogisticRegression on the MNIST latent features) and comparing its performance to a classifier trained on raw pixels; a minimal sketch follows this list.
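Here is a minimal sketch of that last experiment. It is an optional addition (not part of the original walkthrough) and assumes scikit-learn is installed along with the trained `vae`, `device`, and data loaders from earlier; the helper name `encode_dataset` is illustrative.

```python
# Optional: train a simple classifier on the 2-D z_mean features.
from sklearn.linear_model import LogisticRegression

def encode_dataset(encoder_model, loader):
    """Encode every image in a DataLoader into its z_mean vector."""
    encoder_model.eval()
    feats, targets = [], []
    with torch.no_grad():
        for x, y in loader:
            z_mean, _ = encoder_model(x.to(device))
            feats.append(z_mean.cpu().numpy())
            targets.append(y.numpy())
    return np.concatenate(feats), np.concatenate(targets)

X_train, y_train = encode_dataset(vae.encoder, train_loader)
X_test, y_test = encode_dataset(vae.encoder, test_loader)

clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)
print("Test accuracy on 2-D latent features:", clf.score(X_test, y_test))
```

For comparison, you could fit the same classifier on the raw 784-dimensional pixels and see how much (or how little) accuracy the 2-D compression gives up.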
This practical experience forms a solid foundation for applying VAEs to more complex problems and for understanding their role in both feature extraction and generative modeling.