Okay, let's put the theory from the previous sections into practice. Univariate drift detection, while useful, often misses shifts in the relationships between features. Multivariate drift detection methods address this by considering the data distribution as a whole. In this hands-on exercise, we will implement a multivariate drift detection mechanism using the Maximum Mean Discrepancy (MMD) test, a powerful non-parametric method for comparing distributions. We'll use the alibi-detect library, which provides convenient implementations of various drift detection algorithms.
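As a reminder, for a kernel k the squared MMD between distributions P and Q is

$$
\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x,x' \sim P}[k(x, x')] + \mathbb{E}_{y,y' \sim Q}[k(y, y')] - 2\,\mathbb{E}_{x \sim P,\; y \sim Q}[k(x, y)]
$$

In practice the expectations are estimated from the reference and test samples, and a permutation test decides whether the observed value is larger than expected under the no-drift hypothesis.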
First, ensure you have the necessary libraries installed. If not, you can install them using pip:
pip install alibi-detect numpy pandas scikit-learn plotly
We will use NumPy for numerical operations, Scikit-learn for generating synthetic data, and Plotly for visualization. Pandas is installed for potential data manipulation, though it is not needed in this exercise.
import numpy as np
import plotly.graph_objects as go
from sklearn.datasets import make_blobs
from alibi_detect.cd import MMDDrift
print("Libraries imported successfully.")
To demonstrate multivariate drift, we need two datasets: a reference dataset representing the "normal" state (e.g., training data or data from an initial, stable production period) and a new dataset representing potentially drifted data.
Let's create synthetic 5-dimensional data. The reference data will consist of two clusters. The drifted data will have one cluster shifted, representing a change in the underlying data generating process.
# Reference data (representing stable state)
np.random.seed(0)
n_features = 5
n_samples_ref = 500
X_ref, _ = make_blobs(n_samples=n_samples_ref, n_features=n_features, centers=[np.zeros(n_features), np.ones(n_features)*1.5], cluster_std=0.5, random_state=0)
print(f"Reference data shape: {X_ref.shape}")
# Drifted data (representing changed state)
np.random.seed(1)
n_samples_drift = 500
# Shift one of the centers slightly and increase variance
centers_drift = [np.zeros(n_features) + 0.2, np.ones(n_features)*1.5 + 0.3]
X_drift, _ = make_blobs(n_samples=n_samples_drift, n_features=n_features, centers=centers_drift, cluster_std=0.6, random_state=1)
print(f"Drifted data shape: {X_drift.shape}")
Now, we instantiate the MMDDrift detector from alibi-detect. The key arguments are:
- X_ref: The reference dataset. The detector learns the "normal" distribution from this data.
- p_val: The significance level for the statistical test. A common value is 0.05. If the computed p-value falls below this threshold, drift is detected.
- backend: Specifies the computation backend ('tensorflow', 'pytorch', or 'keops'). We'll use 'tensorflow' here, but feel free to adapt if you prefer PyTorch. Ensure TensorFlow is installed (pip install tensorflow).
- preprocess_fn: Optionally, a function to preprocess data before drift detection (e.g., scaling). We'll skip this for simplicity, but it's important in practice; a sketch follows this list.
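For reference, a common choice is to standardize features with statistics computed on the reference data. The scaler and helper below are illustrative only and are not used in the rest of this exercise:

from sklearn.preprocessing import StandardScaler

# Illustrative preprocessing: fit the scaler on the reference data only
scaler = StandardScaler().fit(X_ref)

def standardize_batch(x):
    # Applied to each batch before the MMD test
    return scaler.transform(x).astype(np.float32)

# Would be passed as: MMDDrift(X_ref, backend='tensorflow', p_val=0.05, preprocess_fn=standardize_batch)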
# Configure the MMD drift detector
p_threshold = 0.05  # Significance level

try:
    # Requires a TensorFlow installation: pip install tensorflow
    cd = MMDDrift(X_ref, backend='tensorflow', p_val=p_threshold)
    print("MMDDrift detector initialized successfully.")
except ImportError:
    print("TensorFlow backend not available. Please install tensorflow.")
    cd = None
except Exception as e:
    print(f"Error initializing detector: {e}")
    cd = None
With the detector configured, we can now test for drift on the new dataset (X_drift). The predict method returns a dictionary containing:
- data['is_drift']: 1 if drift is detected, 0 otherwise.
- data['p_val']: The computed p-value from the MMD test.
- data['distance']: The MMD statistic (a measure of distance between the distributions).
- data['distance_threshold']: The critical value for the MMD statistic, estimated from permutations of the reference and test data at the chosen p_val. Drift is flagged when the p-value falls below p_val, which corresponds to distance > distance_threshold.
- data['threshold']: The significance level itself (the p_val you configured).
# Perform drift detection on the new data
if cd:
    preds = cd.predict(X_drift)

    # Print results
    print("\nDrift Detection Results:")
    print(f"  Is drift detected? {'Yes' if preds['data']['is_drift'] else 'No'}")
    print(f"  p-value: {preds['data']['p_val']:.4f}")
    print(f"  MMD distance: {preds['data']['distance']:.4f}")
    print(f"  Distance threshold: {preds['data']['distance_threshold']:.4f}")

    if preds['data']['is_drift']:
        print("\nAnalysis: Drift detected! The distribution of the new data is statistically different from the reference data.")
    else:
        print("\nAnalysis: No significant drift detected.")
else:
    print("\nSkipping detection as detector initialization failed.")
You should see that drift is detected: the p-value falls below our 0.05 threshold, and the computed MMD distance exceeds the distance threshold derived from the permutation test.
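As a sanity check, you can also run the detector on a fresh sample drawn from the same distribution as the reference data; in most runs, no drift should be flagged. A minimal sketch:

# Sanity check: a fresh sample from the original (non-drifted) distribution
if cd:
    X_no_drift, _ = make_blobs(n_samples=500, n_features=n_features,
                               centers=[np.zeros(n_features), np.ones(n_features) * 1.5],
                               cluster_std=0.5, random_state=42)
    no_drift_preds = cd.predict(X_no_drift)
    print(f"Drift on non-drifted sample? "
          f"{'Yes' if no_drift_preds['data']['is_drift'] else 'No'} "
          f"(p-value: {no_drift_preds['data']['p_val']:.4f})")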
In a real system, data arrives continuously or in batches. Let's simulate this by generating small batches of data, some reflecting the original distribution and some the drifted one, and observe how the MMD statistic behaves.
# Simulate data arriving in batches over time
n_batches = 20
batch_size = 100
drift_scores = []
p_values = []
distance_thresholds = []  # MMD distance threshold, recomputed per batch by the permutation test
time_steps = list(range(n_batches))

if cd:
    for i in range(n_batches):
        np.random.seed(i * 10)  # Ensure reproducibility for each batch

        # Introduce drift after 10 batches
        if i < 10:
            # Sample from reference distribution
            centers_batch = [np.zeros(n_features), np.ones(n_features) * 1.5]
            std_batch = 0.5
        else:
            # Sample from drifted distribution
            centers_batch = centers_drift
            std_batch = 0.6

        X_batch, _ = make_blobs(n_samples=batch_size, n_features=n_features,
                                centers=centers_batch, cluster_std=std_batch,
                                random_state=i * 10)

        # Check for drift in the current batch
        batch_preds = cd.predict(X_batch)
        drift_scores.append(batch_preds['data']['distance'])
        p_values.append(batch_preds['data']['p_val'])
        distance_thresholds.append(batch_preds['data']['distance_threshold'])

    print(f"\nSimulated {n_batches} batches. "
          f"Distance threshold for the last batch: {distance_thresholds[-1]:.4f}")
Now, let's visualize the MMD distance (drift score) for each batch compared to the detection threshold.
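The plotting code below is a minimal sketch using the Plotly import from earlier; it shows the per-batch MMD distance together with the permutation-based distance threshold recorded for each batch:

if cd:
    fig = go.Figure()
    # MMD distance between each incoming batch and the reference data
    fig.add_trace(go.Scatter(x=time_steps, y=drift_scores,
                             mode='lines+markers', name='MMD distance'))
    # Permutation-based distance threshold (recomputed per batch)
    fig.add_trace(go.Scatter(x=time_steps, y=distance_thresholds,
                             mode='lines', name='Detection threshold',
                             line=dict(color='red', dash='dash')))
    fig.update_layout(title='MMD drift score per batch',
                      xaxis_title='Batch index',
                      yaxis_title='MMD distance')
    fig.show()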
The MMD distance between incoming data batches and the reference data. After batch 10, the underlying data distribution changes, causing the MMD distance to increase sharply and cross the permutation-based detection threshold (dashed red line), indicating drift.
As expected, the MMD distance remains low for the first 10 batches (sampled from the original distribution) and then jumps significantly above the threshold when the data starts coming from the drifted distribution.
This practical exercise demonstrated how to:
- Generate a reference dataset and a synthetically drifted dataset.
- Initialize an MMDDrift detector from alibi-detect.
- Interpret the detector's predict output: p-value, MMD distance, and distance threshold.
- Simulate batch-wise monitoring and track the MMD statistic over time.

MMD is effective because it compares distributions in a Reproducing Kernel Hilbert Space (RKHS), allowing it to capture complex, non-linear differences. However, be mindful of:
- Computational cost: the permutation test can become expensive for large datasets; alibi-detect uses optimizations like sampling permutations for the threshold calculation.
- Reference data quality: a sufficiently large and representative reference dataset (X_ref) is important for accurately estimating the null distribution (no drift) and setting a reliable threshold.

This hands-on example provides a starting point. In production systems, you would integrate such detectors into your MLOps pipeline, triggering alerts or automated actions (like retraining) when significant drift is confirmed; a minimal illustration of such a hook follows. Remember to adapt the data generation, detector parameters, and simulation logic to match your specific use case.
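The wrapper and callback names below are hypothetical; the sketch only shows where a detector could sit in a monitoring loop:

# Illustrative only: wire the drift detector into a simple monitoring hook
def check_batch_and_alert(detector, x_batch, on_drift):
    """Run drift detection on a batch and invoke a callback when drift is flagged."""
    preds = detector.predict(x_batch)
    if preds['data']['is_drift']:
        on_drift(p_val=preds['data']['p_val'],
                 distance=preds['data']['distance'])
    return preds

# Example usage with a simple logging callback
if cd:
    check_batch_and_alert(cd, X_drift,
                          on_drift=lambda p_val, distance: print(
                              f"ALERT: drift detected (p={p_val:.4f}, MMD={distance:.4f})"))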