Now that we've covered the theoretical foundations of SHAP values, Shapley values, and the different SHAP explainers like KernelSHAP and TreeSHAP, it's time to put this knowledge into practice. This section provides hands-on examples using the shap Python library to calculate and visualize SHAP values, helping you interpret both individual predictions and overall model behavior.
We'll assume you have a trained machine learning model and want to understand its predictions better. For this example, let's imagine we've trained a Gradient Boosting Classifier on a familiar dataset like the Iris dataset.
First, ensure you have the necessary libraries installed. You'll primarily need shap, scikit-learn, and pandas.
# Install shap if you haven't already
# pip install shap
import shap
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import load_iris
# Load the Iris dataset
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
# For demonstration, let's simplify to a binary classification problem
# Class 0 vs Class 1 & 2
y_binary = (y > 0).astype(int)
class_names = ['setosa', 'not setosa'] # Map 0 to 'setosa', 1 to 'not setosa'
# Split data
X_train, X_test, y_train, y_test = train_test_split(X, y_binary, test_size=0.2, random_state=42)
# Train a Gradient Boosting Classifier
model = GradientBoostingClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
print(f"Model trained on {X_train.shape[0]} samples.")
print(f"Testing on {X_test.shape[0]} samples.")
This sets up our environment, with the data loaded into a pandas DataFrame X and the binary target variable y_binary, and trains a GradientBoostingClassifier.
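Before explaining the model, it helps to confirm that it actually fits the data. A quick sanity check (the exact score depends on the random split, though this binary version of Iris is easy and typically scores close to 1.0):
# Quick sanity check of predictive performance before interpreting the model
accuracy = model.score(X_test, y_test)
print(f"Test accuracy: {accuracy:.3f}")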
Since Gradient Boosting is a tree-based ensemble model, the most efficient way to calculate SHAP values is with shap.TreeExplainer. This explainer exploits the internal structure of the trees, making computation far faster than model-agnostic methods like KernelExplainer.
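For context, a model-agnostic setup would look roughly like the sketch below. It is shown only for comparison: KernelExplainer requires a background dataset and re-evaluates the model many times, so it is far slower (the background size of 50 and the 5 explained rows are arbitrary choices for illustration).
# For comparison only: the model-agnostic alternative (much slower than TreeExplainer).
# KernelExplainer perturbs inputs against a background sample and repeatedly calls
# the model's prediction function.
background = shap.sample(X_train, 50, random_state=42)
kernel_explainer = shap.KernelExplainer(model.predict_proba, background)
kernel_shap_values = kernel_explainer.shap_values(X_test.iloc[:5])  # explain a few rows only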
# 1. Initialize the explainer
# We pass the trained model to the explainer.
explainer = shap.TreeExplainer(model)
# 2. Calculate SHAP values
# We compute SHAP values for the test set (or any data we want to explain).
# The output 'shap_values' is typically a list (for multi-class) or array (for binary)
# containing SHAP values for each feature, for each instance.
# For binary classification with scikit-learn models, shap usually returns values
# for the positive class (class 1 in our case).
shap_values = explainer.shap_values(X_test)
# Check the shape: (number of samples, number of features)
print(f"SHAP values shape: {shap_values.shape}")
# Check the expected value (base value)
# This is the average raw model output (log-odds here) over the training/background data
print(f"Explainer expected value (base value): {explainer.expected_value}")
The shap_values variable now holds an array where each row corresponds to an instance in X_test and each column corresponds to a feature. The value itself represents the contribution of that feature to the model's output (logit or probability, depending on the model and explainer settings) for that specific instance, compared to the base value (explainer.expected_value).
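A quick way to convince yourself of this is the additivity (local accuracy) property: for each instance, the base value plus the sum of its SHAP values should reproduce the model's raw output. For this scikit-learn gradient boosting model the raw output is the log-odds margin given by model.decision_function; a small check (note that expected_value may be a scalar or a length-1 array, depending on the shap version):
import numpy as np

# Additivity check: base value + sum of SHAP values per instance should match
# the model's raw (log-odds) output for that instance.
raw_output = model.decision_function(X_test)
reconstructed = np.ravel(explainer.expected_value) + shap_values.sum(axis=1)
print(np.allclose(raw_output, reconstructed, atol=1e-6))  # expected: True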
To understand a single prediction, the force plot is very effective. It shows the features pushing the prediction higher (positive SHAP values, typically red) and those pushing it lower (negative SHAP values, typically blue). Let's explain the prediction for the first instance in our test set (X_test.iloc[0]).
# Initialize JavaScript visualization in the environment (needed for plots)
shap.initjs()
# Explain the first prediction
instance_index = 0
# Create the force plot for the first instance
# We pass the base value, the SHAP values for that instance, and the feature values
shap.force_plot(explainer.expected_value,
                shap_values[instance_index, :],
                X_test.iloc[instance_index, :],
                matplotlib=False)  # use the JavaScript plot
A force plot showing the contribution of each feature to the prediction for a single instance. Features pushing the prediction higher are in red, and features pushing it lower are in blue. The length of the bar indicates the magnitude of the feature's impact.
Interpreting this plot: the "base value" represents the average prediction score across all samples. The features listed are those driving the prediction for this specific instance away from the base value. If the final output value is higher than the base value, the red features (positive SHAP values) have had a stronger collective impact than the blue features (negative SHAP values). You can see exactly which features (like petal width (cm) or sepal length (cm)) had the most significant influence on this particular prediction and in which direction.
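The interactive plot above relies on the JavaScript renderer initialized by shap.initjs(), which works inside notebook environments. If you're running a plain Python script, two common workarounds (a sketch using standard shap helpers; the output filename is just an example) are to save the plot to a standalone HTML file or to use the static matplotlib renderer:
# Option 1: save the interactive force plot to an HTML file and open it in a browser
fp = shap.force_plot(explainer.expected_value,
                     shap_values[instance_index, :],
                     X_test.iloc[instance_index, :])
shap.save_html("force_plot_instance_0.html", fp)

# Option 2: render a static version with matplotlib (single instances only)
shap.force_plot(explainer.expected_value,
                shap_values[instance_index, :],
                X_test.iloc[instance_index, :],
                matplotlib=True)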
While force plots are great for individual predictions, summary plots provide a global overview of feature importance across the entire dataset (or the subset you calculated SHAP values for).
There are several styles of summary plots. The default (plot_type="dot") shows the distribution of SHAP values for each feature.
# Create a summary plot (dot version)
shap.summary_plot(shap_values, X_test, plot_type="dot")
Distribution of SHAP values for each feature across the test set. Each point is a single prediction's SHAP value for a feature. Color indicates the feature's value (red=high, blue=low).
Interpretation:
- petal length (cm) and petal width (cm) are the most influential features.
- For petal length (cm), high values (red points) generally have high positive SHAP values, pushing the prediction towards class 1 ('not setosa'). Low values (blue points) have negative SHAP values, pushing towards class 0 ('setosa').
- sepal width (cm) shows less impact overall (points clustered near zero) and a less clear correlation between feature value and impact direction.

Alternatively, a bar plot shows the mean absolute SHAP value for each feature, providing a clearer view of the overall magnitude of importance.
# Create a summary plot (bar version)
shap.summary_plot(shap_values, X_test, plot_type="bar")
Average absolute SHAP value for each feature across the test set. Higher bars indicate greater average impact on predictions.
This bar chart confirms that petal length and width have the highest average impact on the model's predictions for this dataset.
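If you want the numbers behind the bars rather than the plot, they are simply the mean absolute SHAP value per feature, which you can compute directly (a small sketch):
import numpy as np

# The bar plot's values: mean absolute SHAP value per feature
mean_abs_shap = pd.Series(np.abs(shap_values).mean(axis=0), index=X_test.columns)
print(mean_abs_shap.sort_values(ascending=False))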
Dependence plots show how the SHAP value for a single feature changes as the feature's value varies. They can also reveal interaction effects by coloring points based on the value of another feature.
Let's examine the dependence of the model's output on petal width (cm) and see if it interacts with petal length (cm).
# Create a dependence plot for 'petal width (cm)'
# Color points by 'petal length (cm)' to check for interaction
shap.dependence_plot("petal width (cm)", shap_values, X_test, interaction_index="petal length (cm)")
{"data":[{"x":[0.2,0.2,0.2,0.4,0.2,0.1,0.2,0.2,0.1,0.2,0.4,0.2,0.1,0.2,0.6,0.4,0.3,0.2,0.2,0.2,1.5,1.5,1.6,1.5,1.3,1.6,1. ,1.7,1.4,1.5],"y":[-1.384871,-1.4847989,-1.4847989,-1.4847989,-1.4847989,-1.4847989,-1.384871,-1.4847989,-1.4847989,-1.4847989,-1.4847989,-1.4847989,-1.4847989,-1.384871,-1.4847989,-1.384871,-1.384871,-1.4847989,-1.4847989,-1.4847989,2.7968922,2.7968922,2.7968922,2.7968922,2.8268008,2.7968922,2.7968922,2.8268008,2.8268008,2.8268008],"mode":"markers","marker":{"color":[1.4,1.4,1.3,1.5,1.4,1.4,1.5,1.4,1.5,1.3,1.5,1.5,1. ,1.3,1.4,1.4,1.4,1.7,1.5,1.7,4.5,4.9,5.6,5.1,5.1,5.4,4.5,5.8,6. ,5.1],"colorbar":{"title":"petal length (cm)"},"colorscale":[[0.0,"#339af0"],[1.0,"#f03e3e"]]},"type":"scatter"}],"layout":{"xaxis":{"title":"petal width (cm)"},"yaxis":{"title":"SHAP value for petal width (cm)"},"title":{"text":"SHAP Dependence Plot: petal width (cm)"},"hovermode":"closest"}}
Relationship between the value of 'petal width (cm)' and its corresponding SHAP value. Color indicates the value of 'petal length (cm)'.
Interpretation:
- As petal width (cm) increases, its SHAP value tends to increase, meaning higher petal width strongly pushes the prediction towards 'not setosa'. There's a sharp jump around a petal width of 0.7-0.8 cm.
- Coloring by petal length (cm) helps visualize interactions. Notice how points with high petal width (cm) (right side) consistently have high petal length (cm) (red), while points with low petal width (cm) (left side) mostly have low petal length (cm) (blue). This strong correlation is expected in the Iris dataset. If there were a significant interaction beyond this correlation, you might see vertical separation of the colors at a given petal width (e.g., at a petal width of 1.5, red points having systematically higher or lower SHAP values than blue points). In this case, the interaction seems minimal beyond the inherent correlation between the two features; one way to check this directly is sketched below.
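Rather than eyeballing the colors, you can quantify pairwise interactions with SHAP interaction values, which TreeExplainer can compute for tree ensembles. The sketch below assumes your installed shap version supports shap_interaction_values for scikit-learn gradient boosting; the result has shape (n_samples, n_features, n_features), with off-diagonal entries giving pairwise interaction contributions.
import numpy as np

# Sketch: pairwise SHAP interaction values (tree models only; can be slow on large data)
interaction_values = explainer.shap_interaction_values(X_test)

cols = list(X_test.columns)
i = cols.index("petal width (cm)")
j = cols.index("petal length (cm)")

# Average absolute interaction contribution between the two features
mean_abs_interaction = np.abs(interaction_values[:, i, j]).mean()
print(f"Mean |interaction| between petal width and petal length: {mean_abs_interaction:.4f}")
A value close to zero, relative to the main-effect SHAP values seen earlier, supports the reading that the two features are correlated rather than strongly interacting in the model.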
By generating and analyzing these plots, you move from simply knowing the model's prediction to understanding why it made that prediction, identifying the most influential features both locally and globally, and exploring potential feature interactions. This practical application of SHAP provides significant value for debugging, validating, and building trust in your machine learning models.