Now that we understand the theoretical underpinnings of Shapley values and how SHAP connects them to machine learning features, let's turn to the practical implementation using the popular shap Python library. This library provides efficient implementations of various SHAP algorithms, including KernelSHAP and TreeSHAP, along with powerful visualization tools.
We'll walk through the process of generating SHAP explanations for a sample model. As established in the prerequisites, familiarity with Python and basic machine learning workflows using libraries like scikit-learn and pandas is assumed.
First, ensure you have the necessary libraries installed. You'll primarily need shap, scikit-learn, pandas, and numpy, along with matplotlib for rendering the plots. If you haven't installed them yet, you can typically do so using pip:
pip install shap scikit-learn pandas numpy matplotlib
Now, let's import the required modules in our Python script or notebook:
import shap
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
# Optional: Initialize JS visualization support for environments like Jupyter
# shap.initjs() # Uncomment this line if using Jupyter notebooks/lab
The shap.initjs() call is needed for rendering interactive SHAP plots in environments like Jupyter notebooks.
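The return types of a few shap functions (notably the shape of what shap_values returns for classifiers) have changed between releases, so it can help to confirm which version you are running before following along:
print(shap.__version__)  # some behaviour noted below differs between shap releases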
To demonstrate SHAP, we need a trained machine learning model. Let's create a synthetic dataset and train a RandomForestClassifier. Tree-based models are a good starting point because the shap library includes TreeExplainer, a highly optimized method for these types of models, as discussed previously.
# Generate synthetic classification data
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5,
n_redundant=2, n_classes=2, random_state=42)
feature_names = [f'feature_{i}' for i in range(X.shape[1])]
X_df = pd.DataFrame(X, columns=feature_names)
# Split data
X_train, X_test, y_train, y_test = train_test_split(X_df, y, test_size=0.2, random_state=42)
# Train a RandomForestClassifier model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print(f"Model Accuracy: {model.score(X_test, y_test):.4f}")
We now have a trained RandomForestClassifier ready for explanation.
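Before explaining anything, it can be useful to glance at a few raw predictions so you know what the explainer has to account for; here we print the predicted probabilities for the first three test rows:
# Predicted probabilities for a few test instances; column 1 is the class-1 probability we will explain
print(model.predict_proba(X_test.iloc[:3]))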
The shap library provides different "explainers" optimized for various model types. Since we trained a Random Forest, the most efficient choice is shap.TreeExplainer. If we were working with a model for which a specialized explainer isn't available (like a complex neural network or an SVM with a non-linear kernel), we could use the model-agnostic shap.KernelExplainer.
Let's instantiate TreeExplainer:
# Create a TreeExplainer object
explainer = shap.TreeExplainer(model)
Creating the explainer object often involves no more than passing the trained model. For some explainers, like KernelExplainer, you might also need to provide a background dataset (usually a sample of the training data) to represent the expected distribution of feature values. TreeExplainer derives expectations directly from the tree structures and training data distribution, making its initialization simpler.
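For comparison, here is a minimal sketch of how KernelExplainer could be set up for the same model; the background sample of 100 rows is an illustrative choice, and explaining more than a handful of instances this way is far slower than with TreeExplainer:
# Model-agnostic alternative, shown for illustration only (much slower than TreeExplainer)
background = shap.sample(X_train, 100)  # small background sample of the training data
kernel_explainer = shap.KernelExplainer(model.predict_proba, background)
# kernel_shap_values = kernel_explainer.shap_values(X_test.iloc[:10])  # explain only a few rows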
With the explainer ready, we can calculate SHAP values for any set of instances we want to explain. Typically, you might explain predictions on your test set or specific instances of interest. The shap_values method computes these values.
# Calculate SHAP values for the test set
# Depending on the shap version, TreeExplainer.shap_values may return a list of
# arrays (one per output class), a single 2D array (regression), or a single 3D
# array of shape (n_samples, n_features, n_classes) in newer releases.
# We are usually interested in the SHAP values for the positive class (class 1).
shap_values = explainer.shap_values(X_test)
# Select the SHAP values for class 1, handling the possible return types
if isinstance(shap_values, list):
    shap_values_class1 = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values_class1 = shap_values[:, :, 1]
else:
    # Regression or other single-output cases: already a 2D array
    shap_values_class1 = shap_values
# The explainer also provides the expected value (base value)
# This is the average model prediction over the training data (or background data)
expected_value = explainer.expected_value
# For classification, expected_value is typically per-class, e.g.
# [E[f(X)] for class 0, E[f(X)] for class 1]; it may be a list or a NumPy array.
if isinstance(expected_value, (list, np.ndarray)) and np.ndim(expected_value) > 0:
    expected_value_class1 = expected_value[1]
else:
    expected_value_class1 = expected_value
print(f"SHAP values shape: {shap_values_class1.shape}") # Should be (n_samples, n_features)
print(f"Expected value (base value) for class 1: {expected_value_class1}")
The shap_values variable (here shap_values_class1) holds a NumPy array in which each row corresponds to an instance in X_test and each column corresponds to a feature. The value shap_values_class1[i, j] represents the contribution of feature j to the prediction for instance i, pushing it away from the base value (expected_value). The expected_value is the average model output over the data the explainer was calibrated on: the training data distribution for TreeExplainer, or the background dataset for KernelExplainer. Note that for classification models the SHAP values live in the model's raw output space, which depends on the model and explainer settings: for scikit-learn random forests TreeExplainer explains the predicted probabilities, while for gradient-boosted models it typically explains the margin (log-odds) output.
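To read one of these rows numerically, you can pair the feature names with the SHAP values of a single instance and rank them by absolute impact; a small sketch using the objects defined above:
# Rank features by absolute SHAP value for the first test instance
instance_shap = pd.Series(shap_values_class1[0, :], index=feature_names)
print(instance_shap.reindex(instance_shap.abs().sort_values(ascending=False).index))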
Remember the additive property: for any single prediction $i$, the sum of the SHAP values for all its features plus the base value equals the model's output for that prediction:

$$\text{model\_output}(x_i) \approx \text{expected\_value} + \sum_{j=1}^{M} \text{shap\_value}(x_i, j)$$

where $M$ is the number of features.
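You can check this numerically. The sketch below assumes that TreeExplainer is explaining the probability output of this scikit-learn random forest, so the reconstruction should match predict_proba for class 1; if your version or settings explain a different output space (for example log-odds), compare against that output instead.
# Verify the additive property: base value + sum of SHAP values ≈ class-1 probability
reconstructed = expected_value_class1 + shap_values_class1.sum(axis=1)
predicted = model.predict_proba(X_test)[:, 1]
print(np.allclose(reconstructed, predicted, atol=1e-4))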
The shap library excels at visualization, making it easier to interpret the computed SHAP values.
Force plots are effective for visualizing the explanation of a single prediction. They show which features pushed the model's output higher (positive SHAP value, typically shown in red) and which pushed it lower (negative SHAP value, typically shown in blue).
# Explain the prediction for the first instance in the test set
instance_index = 0
shap.force_plot(expected_value_class1,
shap_values_class1[instance_index, :],
X_test.iloc[instance_index, :],
matplotlib=True) # Use matplotlib=True for static plots if needed
A force plot showing feature contributions for a single prediction. Features pushing the prediction towards the positive class (value > base value) are in red, while features pushing towards the negative class are in blue. The length of the bar indicates the magnitude of the feature's impact.
You can also create a force plot for multiple instances, rotated and stacked vertically:
# Visualize explanations for the first 100 test instances
shap.force_plot(expected_value_class1,
shap_values_class1[:100, :],
X_test.iloc[:100, :])
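Interactive force plots can also be exported as standalone HTML for sharing outside a notebook; shap.save_html accepts the object returned by shap.force_plot (the file name here is just an example):
# Save the interactive stacked force plot to an HTML file
stacked_plot = shap.force_plot(expected_value_class1,
                               shap_values_class1[:100, :],
                               X_test.iloc[:100, :])
shap.save_html("force_plot_first_100.html", stacked_plot)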
A stacked force plot visualizing SHAP values for multiple instances (here, the first 100 from the test set). Each instance is a row, allowing comparison of feature impacts across samples. Instances are often ordered by similarity.
Summary plots provide a global overview of feature importance by combining the SHAP values across all instances. The default summary_plot shows each feature's importance (mean absolute SHAP value) and the distribution of SHAP values for that feature.
# Create a summary plot for overall feature importance
shap.summary_plot(shap_values_class1, X_test, plot_type="dot")
SHAP Summary Plot (dot version). Features are ranked by their global importance (mean absolute SHAP value). Each point represents a SHAP value for a specific instance and feature. The horizontal position shows the SHAP value (impact), and the color indicates the original feature value (high or low). This reveals not just which features are important, but how their values affect the prediction. For example, high values of feature_0 seem to strongly increase the prediction towards class 1.
Alternatively, a bar chart summary plot simply shows the mean absolute SHAP value per feature, giving a straightforward ranking of global feature importance.
# Create a bar summary plot
shap.summary_plot(shap_values_class1, X_test, plot_type="bar")
SHAP Summary Plot (bar version). This plot aggregates the impact by showing the average absolute SHAP value for each feature across all instances provided. It gives a clear ranking of features by their overall contribution to the model's predictions.
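If you want the numbers behind the bar chart, the same ranking can be computed directly from the SHAP value matrix:
# Mean absolute SHAP value per feature (the quantity plotted in the bar summary)
mean_abs_shap = pd.Series(np.abs(shap_values_class1).mean(axis=0), index=feature_names)
print(mean_abs_shap.sort_values(ascending=False))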
Dependence plots show how a single feature's effect on the model output varies across that feature's range of values. They plot the value of a feature against its corresponding SHAP value for all instances, and the point coloring can optionally highlight an interaction effect with a second feature.
# Create a dependence plot for 'feature_0'
# Optionally color by 'feature_3' to see interaction effects
shap.dependence_plot("feature_0", shap_values_class1, X_test, interaction_index="feature_3")
# Create a dependence plot for 'feature_4' without explicit interaction coloring
# shap.dependence_plot("feature_4", shap_values_class1, X_test, interaction_index=None)
SHAP Dependence Plot showing the relationship between the value of feature_0 and its SHAP value across all test samples. The color indicates the value of feature_3, revealing potential interaction effects. Here, higher values of feature_0 generally lead to higher SHAP values (increasing the likelihood of class 1). The coloring might suggest this effect is modulated by feature_3.
These plots provide a powerful toolkit for moving from the theoretical concept of SHAP values to actionable insights about your model's behavior, both for individual predictions and overall trends. The next section provides a hands-on practical exercise to solidify these implementation steps.