After successfully training your autoencoder, the next rewarding step is to visually inspect its performance. While quantitative metrics like Mean Squared Error (MSE), which you learned about when configuring the model, give you a single number to represent the overall reconstruction error, seeing the results with your own eyes provides a different, often more intuitive, understanding of how well your autoencoder is working. This section will guide you through the practical steps of visualizing the original input data alongside the data reconstructed by your autoencoder.
The main job of our basic autoencoder is to reconstruct its input as accurately as possible. By comparing the original images to the ones generated by the autoencoder after compression and decompression, we can qualitatively assess how faithfully each image is reproduced and where detail is lost.
For datasets like MNIST, which consists of handwritten digits, this means we'll be looking at how well the autoencoder can "redraw" a digit after having seen it.
Before we can visualize, we need two sets of images: the originals from your test set and the autoencoder's reconstructions of them.
Assuming you have your trained autoencoder model and a set of test_images (properly preprocessed, just like the training data), you can get the reconstructed images using the model's predict method. This method takes the input data, passes it through the encoder and then the decoder, and gives you the final output.
# Assuming 'autoencoder' is your trained Keras model
# and 'test_images' is your preprocessed test dataset (e.g., MNIST)
# Get the reconstructed images
reconstructed_images = autoencoder.predict(test_images)
The reconstructed_images variable will now hold the autoencoder's attempt at recreating the test_images.
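If your test set is not already preprocessed, the typical steps are scaling pixel values to [0, 1] and flattening each image to match the model's input layer. The sketch below uses a randomly generated stand-in for the raw uint8 MNIST arrays (in practice you would obtain them from your dataset loader, e.g. keras.datasets.mnist.load_data()):

```python
import numpy as np

# Hypothetical stand-in for raw MNIST test images: uint8 pixels in [0, 255].
# Substitute the real array from your dataset loader.
raw_test_images = np.random.randint(0, 256, size=(100, 28, 28), dtype=np.uint8)

# Scale pixel values to [0, 1], matching the training preprocessing.
test_images = raw_test_images.astype("float32") / 255.0

# Flatten each 28x28 image into a 784-element vector, the shape a
# dense autoencoder's input layer expects.
test_images = test_images.reshape((len(test_images), 28 * 28))

print(test_images.shape)  # (100, 784)
```

The same scaling and reshaping must be applied to the test set as was applied to the training set, or the reconstructions will be meaningless.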
We'll use a popular Python library called matplotlib to display these images. If you're working in a Jupyter Notebook or a similar environment, matplotlib can render images directly in your output.
Here's a Python snippet that displays the first n original images from your test set and their corresponding reconstructions in two rows:
import matplotlib.pyplot as plt

# Number of digits to display
n = 10  # You can change this to display more or fewer images

plt.figure(figsize=(20, 4))  # Adjust figure size as needed
for i in range(n):
    # --- Display original images ---
    # Create a subplot for the original image
    ax_original = plt.subplot(2, n, i + 1)
    # Display the image. MNIST images are 28x28 pixels.
    # We reshape the flattened image array back to its 2D shape.
    plt.imshow(test_images[i].reshape(28, 28), cmap='gray')
    # Remove x and y axis ticks and labels for a cleaner look
    ax_original.get_xaxis().set_visible(False)
    ax_original.get_yaxis().set_visible(False)
    # Add a title to the first original image subplot
    if i == 0:
        ax_original.set_title("Original Images")

    # --- Display reconstructed images ---
    # Create a subplot for the reconstructed image
    ax_reconstructed = plt.subplot(2, n, i + 1 + n)
    # Display the reconstructed image, also reshaped.
    plt.imshow(reconstructed_images[i].reshape(28, 28), cmap='gray')
    # Remove x and y axis ticks and labels
    ax_reconstructed.get_xaxis().set_visible(False)
    ax_reconstructed.get_yaxis().set_visible(False)
    # Add a title to the first reconstructed image subplot
    if i == 0:
        ax_reconstructed.set_title("Reconstructed Images")

plt.show()  # This command renders the plot
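A useful companion view is the pixel-wise difference between each original and its reconstruction, which highlights exactly where the autoencoder loses detail. This is a minimal sketch; the randomly generated arrays below are placeholders for your real test_images and reconstructed_images:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical data for illustration only; in practice use your real
# test_images and reconstructed_images arrays.
rng = np.random.default_rng(0)
test_images = rng.random((10, 784))
reconstructed_images = np.clip(test_images + rng.normal(0, 0.1, (10, 784)), 0, 1)

n = 5
plt.figure(figsize=(10, 2))
for i in range(n):
    ax = plt.subplot(1, n, i + 1)
    # Absolute pixel-wise difference: bright areas mark reconstruction errors.
    diff = np.abs(test_images[i] - reconstructed_images[i]).reshape(28, 28)
    plt.imshow(diff, cmap='hot')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if i == 0:
        ax.set_title("Pixel-wise error")
plt.show()
```

Dark difference images mean faithful reconstructions; bright regions show where strokes were blurred, thinned, or dropped.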
Let's break down what this code does:
- import matplotlib.pyplot as plt: imports the plotting functionality.
- n = 10: we decide to display 10 images.
- plt.figure(figsize=(20, 4)): creates a new figure (like a canvas) for our plots. figsize controls its width and height in inches.
- The for loop iterates n times, once for each image we want to display.
- plt.subplot(2, n, i + 1): creates a grid of subplots. 2 means two rows, n means n columns, and i + 1 is the index of the current subplot in the top row (for original images).
- plt.imshow(test_images[i].reshape(28, 28), cmap='gray'): the core command for displaying an image. test_images[i] selects the i-th image from our test set. .reshape(28, 28) converts the flattened image data (for example, a 784-element array for a 28x28 image) back into a 2D array (28 rows, 28 columns), which imshow expects for grayscale images. cmap='gray' tells matplotlib to display the image in grayscale.
- ax.get_xaxis().set_visible(False) and ax.get_yaxis().set_visible(False): hide the numbered axes, as they are not very informative for viewing these small images.
- ax.set_title(...): adds a title above the first image in each row for clarity.
- plt.subplot(2, n, i + 1 + n): creates a subplot in the second row (for reconstructed images).
- plt.show(): finally, this displays the entire figure with all its subplots.

When you run the code, you should see two rows of images. The top row will show the original handwritten digits, and the bottom row will show the digits as reconstructed by your autoencoder.
Look closely at the pairs: check whether each reconstruction is sharp or blurry, and whether any digits have lost strokes or changed shape enough to be mistaken for another digit.
These visual checks are very valuable. They tell you how well your autoencoder has learned to capture the essential features of the data needed for reconstruction. If the reconstructions are very poor (e.g., unrecognizable blobs), it might indicate issues with your model architecture (too small a bottleneck, not enough layers or neurons), the amount of training, or the learning rate. If they are very good, it shows your autoencoder is effectively learning a useful, compressed representation.
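To connect this visual check back to the quantitative side, you can compute the MSE for each image individually and pull out the worst reconstructions for closer inspection. A minimal sketch, using randomly generated placeholders for your real test_images and reconstructed_images:

```python
import numpy as np

# Hypothetical arrays for illustration only; substitute your real
# test_images and reconstructed_images.
rng = np.random.default_rng(42)
test_images = rng.random((10, 784)).astype("float32")
reconstructed_images = np.clip(
    test_images + rng.normal(0, 0.05, (10, 784)).astype("float32"), 0, 1
)

# Mean squared error per image: the average squared pixel difference.
per_image_mse = np.mean((test_images - reconstructed_images) ** 2, axis=1)

# The highest-error images are good candidates for a closer visual look.
worst = np.argsort(per_image_mse)[::-1][:3]
print("Mean MSE over the set:", per_image_mse.mean())
print("Indices with highest error:", worst)
```

Visualizing the few images with the highest per-image error often reveals which digit shapes or writing styles the model struggles with most.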
This hands-on visualization step bridges the gap between the abstract loss numbers and a tangible understanding of your autoencoder's behavior. It's an important part of the iterative process of building and refining machine learning models. In the next section, we'll take a closer look at the data in its compressed form by examining the output of the encoder, often called the "encoded data" or "latent space representation."
© 2025 ApX Machine Learning