Having established the theoretical foundations of Variational Autoencoders, including the probabilistic encoder/decoder structure, the reparameterization trick, and the Evidence Lower Bound (ELBO) objective, we can now translate this theory into practice. This section guides you through building and training a VAE using a popular deep learning framework to generate new images, specifically focusing on the MNIST dataset of handwritten digits. This practical exercise reinforces the concepts learned and demonstrates the generative capabilities of VAEs.
We'll assume you have a working Python environment with TensorFlow and Keras installed, along with standard libraries such as NumPy, SciPy, and Matplotlib.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
A VAE consists of three main parts: the encoder network, the sampling mechanism using the reparameterization trick, and the decoder network.
The encoder, often denoted $q_\phi(z \mid x)$, takes an input image $x$ and maps it to the parameters of a probability distribution in the latent space. For VAEs, this is typically a multivariate Gaussian with a diagonal covariance matrix, so the encoder outputs two vectors: the mean $\mu$ and the log-variance $\log \sigma^2$ of this distribution. Predicting the log-variance rather than the variance itself lets the network output any real number while still guaranteeing a positive variance after exponentiation, which improves numerical stability during training.
Let's define a convolutional encoder suitable for MNIST images (28x28 grayscale).
latent_dim = 2 # Using 2 dimensions for easy visualization
encoder_inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var], name="encoder")
encoder.summary()
The encoder takes a 28x28x1 image, processes it through convolutional and dense layers, and outputs the z_mean and z_log_var vectors, each of size latent_dim.
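As a quick sanity check (a minimal sketch, not needed for training), we can pass a dummy batch through the encoder and confirm that both outputs have shape (batch_size, latent_dim):

# Optional sanity check: the encoder returns two vectors per input image
dummy_batch = tf.zeros((4, 28, 28, 1))  # four blank 28x28 grayscale "images"
mu, log_var = encoder(dummy_batch)
print(mu.shape, log_var.shape)  # both (4, 2), since latent_dim = 2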
To sample from the distribution $q_\phi(z \mid x)$ defined by $\mu$ and $\log \sigma^2$ in a way that allows gradients to flow back through the sampling process, we use the reparameterization trick: $z = \mu + \sigma \odot \epsilon$, where $\epsilon$ is sampled from a standard normal distribution $\mathcal{N}(0, I)$. We can implement this as a custom Keras layer.
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        # sigma = exp(0.5 * log(sigma^2)), so this computes z = mu + sigma * epsilon
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
This layer takes z_mean and z_log_var as input and outputs a sample z.
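Because a fresh $\epsilon$ is drawn on every call, the layer is stochastic: calling it twice on identical inputs yields different samples. A quick, purely illustrative check:

# Two calls with identical inputs give different z, since epsilon is redrawn
sampler = Sampling()
mu = tf.zeros((1, latent_dim))       # mean 0
log_var = tf.zeros((1, latent_dim))  # log-variance 0, i.e. sigma = 1
print(sampler([mu, log_var]).numpy())
print(sampler([mu, log_var]).numpy())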
The decoder, $p_\theta(x \mid z)$, takes a point $z$ from the latent space and maps it back to the data space, attempting to reconstruct the original input or generate a new, similar sample. Since our input is an image, the decoder uses transposed convolutional layers to upsample the latent vector back into a 28x28x1 image. The final activation is typically sigmoid for pixel values normalized between 0 and 1.
latent_inputs = keras.Input(shape=(latent_dim,))
# Project z to a 7x7x64 feature map; the two stride-2 transposed convolutions
# below then upsample the spatial size from 7 -> 14 -> 28
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Final layer reconstructs the image, use sigmoid for pixel probabilities [0, 1]
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
Now, we connect the encoder, sampling layer, and decoder to form the end-to-end VAE model. We define a custom Keras Model class to handle the custom training step involving the ELBO loss calculation.
A diagram illustrating the flow of data through the VAE architecture during training, including the calculation of the loss components.
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.sampling = Sampling()
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            # Encode input to get mean and log variance
            z_mean, z_log_var = self.encoder(data)
            # Sample from the latent distribution
            z = self.sampling([z_mean, z_log_var])
            # Decode the latent sample to reconstruct the input
            reconstruction = self.decoder(z)
            # Reconstruction loss: per-pixel binary cross-entropy, summed over
            # the 28x28 spatial dimensions, then averaged over the batch
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # KL divergence to the standard normal prior, computed analytically:
            # D_KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            # Total loss is the sum of reconstruction and KL loss (negative ELBO)
            total_loss = reconstruction_loss + kl_loss
        # Compute gradients and update weights
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metrics
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, inputs):
        # Forward pass used for inference/prediction: encode, sample, decode
        z_mean, z_log_var = self.encoder(inputs)
        z = self.sampling([z_mean, z_log_var])
        return self.decoder(z)
In this VAE class:

- The __init__ method stores the encoder and decoder and initializes metric trackers for the total loss, reconstruction loss, and KL divergence loss.
- The train_step method overrides the default training logic. It performs a forward pass, calculates the reconstruction loss (binary cross-entropy, which suits the sigmoid output and [0, 1] normalized pixels) and the KL divergence analytically, computes the total loss (the negative ELBO), and applies gradients.
- The call method defines the forward pass for prediction/inference: encoding, sampling, and decoding.

We'll use the standard MNIST dataset. The pixel values should be normalized to the [0, 1] range, which matches the sigmoid activation in the decoder's final layer.
# Load and preprocess MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
mnist_digits = np.concatenate([x_train, x_test], axis=0)
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
# Instantiate the VAE model
vae = VAE(encoder, decoder)
# Compile the model
vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
# Train the VAE
history = vae.fit(mnist_digits, epochs=30, batch_size=128) # 30 epochs is illustrative
# Plot training loss
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Total Loss')
plt.plot(history.history['reconstruction_loss'], label='Reconstruction Loss')
plt.plot(history.history['kl_loss'], label='KL Loss')
plt.title('VAE Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
During training, you'll observe the reconstruction loss decreasing as the VAE learns to reproduce the input digits, while the KL loss encourages the latent distribution $q_\phi(z \mid x)$ to stay close to the standard normal prior $p(z)$. The balance between these two terms is essential for good generative performance.
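If you want to experiment with this balance, one common adjustment (the β-VAE idea; a sketch rather than part of the model above) is to weight the KL term by a factor beta. The only change is the total-loss line inside train_step:

# Sketch (beta-VAE): weight the KL term inside train_step.
# beta > 1 pushes q(z|x) harder toward the prior (smoother latent space,
# blurrier reconstructions); beta = 1 recovers the standard ELBO.
beta = 4.0  # hypothetical hyperparameter; tune for your task
total_loss = reconstruction_loss + beta * kl_loss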
The power of a VAE lies in its ability to generate new data. We can achieve this by sampling points $z$ from the prior distribution (a standard Gaussian $\mathcal{N}(0, I)$ in our case) and passing them through the trained decoder network $p_\theta(x \mid z)$.
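In its simplest form this takes two lines: draw $z$ from the prior and decode it. A minimal sketch, assuming the vae trained above:

# Generate new digits by sampling z directly from the prior N(0, I)
z_prior = tf.random.normal(shape=(16, latent_dim))
generated = vae.decoder.predict(z_prior, verbose=0)  # shape (16, 28, 28, 1)

plt.figure(figsize=(8, 2))
for i in range(16):
    ax = plt.subplot(2, 8, i + 1)
    ax.imshow(generated[i].reshape(28, 28), cmap="Greys_r")
    ax.axis("off")
plt.suptitle("Digits sampled from the prior")
plt.show()

With only two latent dimensions, we can go further and decode a whole grid of points that sweeps the latent space systematically: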
def plot_latent_samples(vae, n=15, figsize=15):
    """Display an n x n 2D manifold of digits decoded from a latent grid."""
    from scipy.stats import norm

    digit_size = 28
    scale = 1.0
    figure = np.zeros((digit_size * n, digit_size * n))
    # Space the grid points via the percent point function (inverse CDF) of
    # the standard normal, so the grid covers the prior's mass more evenly
    grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
    grid_y = norm.ppf(np.linspace(0.05, 0.95, n))[::-1]  # reverse y-axis

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            # Decode the latent grid point into an image
            z_sample = np.array([[xi, yi]]) * scale
            x_decoded = vae.decoder.predict(z_sample, verbose=0)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size - start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    plt.xticks(pixel_range, np.round(grid_x, 1))
    plt.yticks(pixel_range, np.round(grid_y, 1))
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.title("Digits Generated from Latent Space Samples")
    plt.show()
plot_latent_samples(vae)
Running plot_latent_samples will generate a grid of digits. Since we used latent_dim=2, we can sample points on a 2D grid within the approximate range of the prior distribution and visualize the corresponding generated digits. You should observe a smooth transition between different digit styles as you move across the latent space, demonstrating that the VAE has learned a meaningful representation.
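That smoothness can also be probed directly. A sketch, assuming the trained vae from above: encode two test digits, linearly interpolate between their latent means, and decode each intermediate point.

# Interpolate between two test digits in latent space
x_pair = np.expand_dims(x_test[:2], -1).astype("float32") / 255
z_pair, _ = vae.encoder.predict(x_pair, verbose=0)

n_steps = 10
alphas = np.linspace(0.0, 1.0, n_steps)
z_interp = np.array([(1 - a) * z_pair[0] + a * z_pair[1] for a in alphas])
x_interp = vae.decoder.predict(z_interp, verbose=0)

plt.figure(figsize=(10, 1.5))
for i in range(n_steps):
    ax = plt.subplot(1, n_steps, i + 1)
    ax.imshow(x_interp[i].reshape(28, 28), cmap="Greys_r")
    ax.axis("off")
plt.suptitle("Latent-space interpolation between two digits")
plt.show()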
We can also visualize how the learned latent space organizes the data by encoding the MNIST test set and plotting the resulting z_mean vectors, colored by their actual digit labels. This helps us understand the structure captured by the encoder.
def plot_label_clusters(vae, data, labels):
    """Display a 2D plot of the digit classes in the latent space."""
    z_mean, _ = vae.encoder.predict(data, verbose=0)
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(
        z_mean[:, 0], z_mean[:, 1], c=labels, cmap="tab10", alpha=0.7, s=5
    )
    plt.colorbar(scatter, label="Digit Class")
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.title("MNIST Test Set in VAE Latent Space (z_mean)")
    plt.grid(True)
    plt.show()
# Preprocess the test images the same way as the training data,
# and use a subset for faster plotting
x_test_scaled = np.expand_dims(x_test, -1).astype("float32") / 255
num_samples_plot = 5000
plot_label_clusters(vae, x_test_scaled[:num_samples_plot], y_test[:num_samples_plot])
The 2D latent space (showing the mean $\mu$ of the approximate posterior $q_\phi(z \mid x)$) for a sample of MNIST test digits, colored by class. Digits of the same class tend to cluster together, and the space exhibits structure related to digit similarity.
Ideally, the plot will show clusters corresponding to different digits, indicating that the encoder has learned to map similar digits to nearby locations in the latent space. The clusters may not be perfectly separated, especially with such a small latent dimension and limited training, but the general organization should be apparent.
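As a final, optional check (a sketch using the trained model), it's instructive to compare a few test digits with their reconstructions; this makes the reconstruction term of the loss tangible:

# Compare original test digits (top row) with their VAE reconstructions (bottom row)
x_sample = np.expand_dims(x_test[:8], -1).astype("float32") / 255
x_recon = vae.predict(x_sample, verbose=0)  # encode, sample, decode via call()

plt.figure(figsize=(8, 2))
for i in range(8):
    plt.subplot(2, 8, i + 1)
    plt.imshow(x_sample[i].reshape(28, 28), cmap="Greys_r")
    plt.axis("off")
    plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(x_recon[i].reshape(28, 28), cmap="Greys_r")
    plt.axis("off")
plt.suptitle("Top: originals. Bottom: reconstructions.")
plt.show()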
This hands-on implementation demonstrates the core components and training procedure of a VAE for image generation. You've seen how to define the encoder, decoder, and sampling layer, implement the ELBO loss, train the model, and use it to generate new data and visualize the learned representation space. This forms a solid basis for exploring more advanced VAE variants and applications discussed later.