Often in data analysis, you need to compare different views of your data or display multiple related plots together. While creating separate figures for each plot works, it's often more effective to arrange them within a single figure. Matplotlib provides a powerful mechanism for this using subplots. This approach allows for direct visual comparison and creates more organized, comprehensive visualizations.
The plt.subplots() Approach
The recommended way to create a figure containing multiple subplots is using the plt.subplots() function (notice the 's' at the end). This function creates a figure and a grid of subplots (Axes) simultaneously and returns both objects.
import matplotlib.pyplot as plt
import numpy as np
# Create some sample data
x = np.linspace(0, 2 * np.pi, 100)
y_sin = np.sin(x)
y_cos = np.cos(x)
y_tan = np.tan(x)
random_data = np.random.randn(1000)
# Create a figure and a 2x2 grid of subplots
# fig is the entire figure window
# axes is a 2D NumPy array containing the individual subplot (Axes) objects
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
print(axes.shape)
# Output: (2, 2)
# axes[0, 0] refers to the subplot in the first row, first column
# axes[0, 1] refers to the subplot in the first row, second column
# etc.
plt.subplots() takes several arguments, the most common being nrows and ncols, which specify the grid dimensions. The figsize argument controls the overall size of the figure as a (width, height) pair in inches. Crucially, plt.subplots() returns two objects:
fig: A matplotlib.figure.Figure object, representing the entire window or page on which everything is drawn.
axes: A NumPy array of matplotlib.axes.Axes objects. Each Axes object represents a single subplot in the grid and has its own plotting methods.
If you create a single row or single column of subplots (e.g., plt.subplots(3, 1) or plt.subplots(1, 3)), axes will be a 1D NumPy array. For a 1x1 grid, such as a plain plt.subplots() call, the default squeeze=True returns a single Axes object instead of an array.
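The return shapes described above are easy to check directly. The following is a minimal sketch that assumes a non-interactive run, so it selects the Agg backend (an assumption, not something the text requires). It inspects the axes object for several grid sizes, including squeeze=False, which always returns a 2D array.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run, no window needed
import matplotlib.pyplot as plt
import numpy as np

# 2x2 grid: axes is a 2D array
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
print(axes.shape)               # (2, 2)

# Single column: axes is a 1D array
fig_col, axes_col = plt.subplots(3, 1)
print(axes_col.shape)           # (3,)

# Single row: also a 1D array
fig_row, axes_row = plt.subplots(1, 3)
print(axes_row.shape)           # (3,)

# 1x1 grid with the default squeeze=True: a single Axes, not an array
fig_one, ax_one = plt.subplots()
print(isinstance(ax_one, np.ndarray))  # False

# squeeze=False keeps a 2D array even for a single row
fig_nosq, axes_nosq = plt.subplots(1, 3, squeeze=False)
print(axes_nosq.shape)          # (1, 3)

plt.close("all")
```

If your code always indexes axes as axes[row, col], passing squeeze=False avoids special cases for single-row or single-column grids.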
Once you have the axes array, you can plot data onto specific subplots by calling plotting methods directly on the corresponding Axes object, rather than using the global plt functions (plt.plot(), plt.hist(), etc.). You access the individual Axes objects using standard array indexing.
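Indexing one position at a time gets repetitive for larger grids. A common pattern, sketched here under the same headless-backend assumption, is to loop over axes.flat (an iterator over the array in row-major order) and pair it with the data for each panel. The titles and series below are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)

# Illustrative data: one (title, y-values) pair per subplot, in row-major order
series = [
    ("sin(x)", np.sin(x)),
    ("cos(x)", np.cos(x)),
    ("sin(2x)", np.sin(2 * x)),
    ("cos(2x)", np.cos(2 * x)),
]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))

# axes.flat visits axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]
for ax, (title, y) in zip(axes.flat, series):
    ax.plot(x, y)
    ax.set_title(title)

# The same Axes objects are still reachable by 2D indexing
print(axes[1, 0].get_title())  # sin(2x)
plt.close(fig)
```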
# Plot on the top-left subplot (index [0, 0])
axes[0, 0].plot(x, y_sin, color='#4263eb')
axes[0, 0].set_title('Sine Wave')
axes[0, 0].set_xlabel('Radians')
axes[0, 0].set_ylabel('Value')
axes[0, 0].grid(True, linestyle='--', alpha=0.6)
# Plot on the top-right subplot (index [0, 1])
axes[0, 1].plot(x, y_cos, color='#12b886')
axes[0, 1].set_title('Cosine Wave')
axes[0, 1].set_xlabel('Radians')
# We might omit the y-label here if the axes are shared or the context is clear
# Plot on the bottom-left subplot (index [1, 0])
axes[1, 0].scatter(x, y_sin + np.random.normal(0, 0.1, 100), color='#f06595', alpha=0.6, s=10) # Scatter plot
axes[1, 0].set_title('Noisy Sine Scatter')
axes[1, 0].set_xlabel('Radians')
axes[1, 0].set_ylabel('Value')
# Plot on the bottom-right subplot (index [1, 1])
axes[1, 1].hist(random_data, bins=30, color='#fd7e14', alpha=0.7)
axes[1, 1].set_title('Histogram of Random Data')
axes[1, 1].set_xlabel('Value')
axes[1, 1].set_ylabel('Frequency')
# Display the figure with all subplots
plt.show() # This command is generally needed outside interactive environments like Jupyter
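Notice how methods like plot(), set_title(), set_xlabel(), set_ylabel(), and grid() are called on the specific Axes object (e.g., axes[0, 0].plot(...)). This object-oriented approach gives you precise control over each subplot. Note that several Axes methods carry a set_ prefix that their pyplot counterparts lack: plt.title() corresponds to ax.set_title(), and plt.xlabel() to ax.set_xlabel().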
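To see why calling methods on a specific Axes is more precise, compare it with the pyplot functions, which act on whatever Matplotlib considers the "current" Axes. In the sketch below (headless backend assumed), the current Axes right after plt.subplots() is the last subplot created, axes[1, 1], so a bare plt.plot() lands there rather than in the top-left panel.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
fig, axes = plt.subplots(nrows=2, ncols=2)

# Object-oriented: the target subplot is named explicitly
line, = axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title("Sine Wave")

# pyplot-style: acts on the "current" Axes, here the last subplot created
plt.plot(x, np.cos(x))
plt.title("Cosine Wave")

current = plt.gca()
print(line.axes is axes[0, 0])     # True
print(current is axes[1, 1])       # True
print(axes[1, 1].get_title())      # Cosine Wave
plt.close(fig)
```

With the object-oriented calls, the destination of each command is visible in the code itself, which matters once a figure holds more than one subplot.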
When comparing plots, it's often beneficial to have them share the same x-axis or y-axis limits and ticks. This makes direct comparison much easier. plt.subplots() provides convenient arguments for this: sharex and sharey.
Setting sharex=True means all subplots share the same x-axis: zooming or panning along one plot's x-axis affects all the others. Similarly, sharey=True links the y-axes. You can also set either argument to 'col' or 'row' to share axes only among subplots in the same column or row, respectively.
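The 'col' and 'row' options can be checked by comparing axis limits. In this sketch (headless backend assumed; the limit values are arbitrary), setting the x-limits on axes[0, 0] propagates to axes[1, 0] in the same column but not to axes[0, 1], while the y-axes are linked only along each row.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt

# x-axes linked within each column, y-axes linked within each row
fig, axes = plt.subplots(nrows=2, ncols=2, sharex="col", sharey="row")

axes[0, 0].set_xlim(0, 10)   # propagates down column 0
axes[0, 0].set_ylim(-5, 5)   # propagates across row 0

print(axes[1, 0].get_xlim())  # matches axes[0, 0]: same column
print(axes[0, 1].get_xlim())  # unchanged default: different column
print(axes[0, 1].get_ylim())  # matches axes[0, 0]: same row
print(axes[1, 1].get_ylim())  # unchanged default: different row
plt.close(fig)
```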
# Create a 2x1 grid where subplots share the x-axis
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8, 6), sharex=True)
axes[0].plot(x, y_sin, color='#7048e8')
axes[0].set_title('Sine Wave')
axes[0].set_ylabel('Value')
axes[0].grid(True, linestyle='--', alpha=0.6)
axes[1].plot(x, y_cos, color='#0ca678')
axes[1].set_title('Cosine Wave')
axes[1].set_xlabel('Radians')
axes[1].set_ylabel('Value')
axes[1].grid(True, linestyle='--', alpha=0.6)
# Notice only the bottom plot needs an x-label since they are shared.
# Matplotlib automatically hides the x-tick labels on the top plot.
plt.show()
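When axes are shared, plt.subplots() also hides redundant tick labels: with sharex=True (or 'col'), x-axis tick labels appear only on the bottom row, and with sharey=True (or 'row'), y-axis tick labels appear only on the left column. That is why the top plot in the example above has no x-tick labels, which gives a cleaner look.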
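One way to confirm this behavior is to ask each Axes which tick labels are currently visible. The sketch below (headless backend assumed) relies on get_xticklabels() returning only the visible labels, which holds in recent Matplotlib versions.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 100)
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))

# get_xticklabels() returns only the tick labels that are set to be visible
top_labels = axes[0].get_xticklabels()
bottom_labels = axes[1].get_xticklabels()
print(len(top_labels))          # 0: hidden on the top plot
print(len(bottom_labels) > 0)   # True: shown on the bottom plot
plt.close(fig)
```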
Sometimes titles, axis labels, or tick labels from neighboring subplots overlap, making the figure difficult to read. Matplotlib provides tight_layout(), which automatically adjusts the subplot parameters so that these elements fit without overlapping.
# Using the 2x2 example from before:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
axes[0, 0].plot(x, y_sin, color='#4263eb')
axes[0, 0].set_title('Sine Wave')
# (x- and y-labels from the earlier example are omitted here for brevity)
axes[0, 1].plot(x, y_cos, color='#12b886')
axes[0, 1].set_title('Cosine Wave')
axes[1, 0].scatter(x, y_sin + np.random.normal(0, 0.1, 100), color='#f06595', alpha=0.6, s=10)
axes[1, 0].set_title('Noisy Sine Scatter')
axes[1, 1].hist(random_data, bins=30, color='#fd7e14', alpha=0.7)
axes[1, 1].set_title('Histogram of Random Data')
# Add tight_layout() call AFTER plotting and setting labels/titles
plt.tight_layout()
plt.show()
Calling plt.tight_layout() (or fig.tight_layout()) after all plotting commands usually resolves most overlapping issues by adjusting the spacing between subplots. For more fine-grained control over spacing, you can use plt.subplots_adjust() (or fig.subplots_adjust()), though tight_layout() is sufficient for many common cases.
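For manual spacing, subplots_adjust() takes the subplot parameters directly: left, right, bottom, and top are positions given as fractions of the figure width or height, while wspace and hspace set the gaps between subplots as fractions of the average Axes width and height. The values below are arbitrary and only illustrate the call (headless backend assumed).

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
for ax in axes.flat:
    ax.set_title("Title")
    ax.set_xlabel("x label")
    ax.set_ylabel("y label")

# Manual spacing: more vertical room for titles, more horizontal room for y-labels
fig.subplots_adjust(left=0.1, right=0.95, bottom=0.08, top=0.92,
                    wspace=0.3, hspace=0.4)

# The applied values are stored on the figure's SubplotParams
print(fig.subplotpars.wspace, fig.subplotpars.hspace)  # 0.3 0.4
plt.close(fig)
```

Unlike tight_layout(), these values are not derived from the label sizes, so they may need retuning if you change figsize or font sizes.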
Working with subplots is fundamental for creating informative dashboards and comparative visualizations in data analysis and machine learning. By mastering the plt.subplots() function and the object-oriented plotting approach using Axes objects, you gain significant control over the layout and appearance of complex figures.
© 2025 ApX Machine Learning