
Mastering SHAP-IQ Visualizations: Unveiling Machine Learning Model Insights

Discover how SHAP-IQ visualizations help interpret machine learning models by breaking down feature contributions and interactions with detailed visual tools.

Exploring SHAP-IQ Visualizations for Model Interpretability

SHAP-IQ visualizations offer powerful insights into how machine learning models generate predictions by breaking down complex behaviors into understandable components. These visualizations reveal both individual feature contributions and interactions between features, helping users interpret model decisions clearly.

Setting Up the Environment

To start, install the necessary dependencies:

!pip install shapiq overrides scikit-learn pandas numpy seaborn

Import essential libraries and verify the shapiq version:

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import shapiq
 
print(f"shapiq version: {shapiq.__version__}")

Loading and Preparing the Dataset

We utilize the MPG (Miles Per Gallon) dataset from Seaborn, which contains features such as horsepower, weight, and origin for various car models.

import seaborn as sns
df = sns.load_dataset("mpg")
df

Preprocess the data by dropping missing values and encoding categorical variables:

import pandas as pd
from sklearn.preprocessing import LabelEncoder
 
df = df.dropna()
le = LabelEncoder()
df.loc[:, "origin"] = le.fit_transform(df["origin"])
 
for i, label in enumerate(le.classes_):
    print(f"{label}{i}")

Splitting Data and Training the Model

Separate features and target variable, then split into training and test sets:

X = df.drop(columns=["mpg", "name"])
y = df["mpg"]
 
feature_names = X.columns.tolist()
x_data, y_data = X.values, y.values
 
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=42)

Train a Random Forest Regressor:

model = RandomForestRegressor(random_state=42, max_depth=10, n_estimators=10)
model.fit(x_train, y_train)

Evaluating Model Performance

Calculate Mean Squared Error and R2 Score:

y_test_pred = model.predict(x_test)
mse = mean_squared_error(y_test, y_test_pred)
r2 = r2_score(y_test, y_test_pred)
print(f"Mean Squared Error: {mse:.2f}")
print(f"R2 Score: {r2:.2f}")

Explaining a Single Prediction

Analyze how the model arrived at its prediction for a single test instance (instance_id = 7):

instance_id = 7
x_explain = x_test[instance_id]
y_true = y_test[instance_id]
y_pred = model.predict(x_explain.reshape(1, -1))[0]
print(f"Instance {instance_id}, True Value: {y_true}, Predicted Value: {y_pred}")
for i, feature in enumerate(feature_names):
    print(f"{feature}: {x_explain[i]}")

Generating Shapley-based Explanations

Create explanations for different interaction orders (individual features, pairwise interactions, and full interactions):

feature_names = list(X.columns)
n_features = len(feature_names)
 
si_order = {}
for order in tqdm([1, 2, n_features]):
    index = "k-SII" if order > 1 else "SV"
    explainer = shapiq.TreeExplainer(model=model, max_order=order, index=index)
    si_order[order] = explainer.explain(x=x_explain)
si_order
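Before plotting, a quick sanity check is to verify the efficiency property of the order-1 explanation: the Shapley values plus the baseline should approximately reconstruct the model's prediction. A minimal sketch, assuming the returned InteractionValues objects expose a dict_values mapping and a baseline_value attribute (as in recent shapiq releases):

# Efficiency check: single-feature contributions plus the baseline should
# roughly equal the prediction for this instance
sv_explanation = si_order[1]
shapley_sum = sum(
    value for coalition, value in sv_explanation.dict_values.items()
    if len(coalition) == 1
)
print(f"baseline + sum of Shapley values: {sv_explanation.baseline_value + shapley_sum:.2f}")
print(f"model prediction: {y_pred:.2f}")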

Visualization Techniques

1. Force Chart

This chart shows how each feature pushes the prediction above or below the baseline: red bars increase the prediction, blue bars decrease it. For the higher-order explanations, interaction terms appear as additional contributions alongside the individual features.

sv = si_order[1]
si = si_order[2]
mi = si_order[n_features]
 
sv.plot_force(feature_names=feature_names, show=True)
si.plot_force(feature_names=feature_names, show=True)
mi.plot_force(feature_names=feature_names, show=True)
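If you are running outside a notebook, the charts can be written to disk instead of displayed. A small sketch, assuming that with show=False the figure remains on the current matplotlib canvas (recent shapiq versions may also return the figure directly):

import matplotlib.pyplot as plt
 
# Render without displaying, then grab the current figure and save it
sv.plot_force(feature_names=feature_names, show=False)
plt.gcf().savefig("force_chart_sv.png", dpi=150, bbox_inches="tight")
plt.close()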

2. Waterfall Chart

Similar to the force chart, but minor feature contributions are grouped into an "other" category for clarity.

sv.plot_waterfall(feature_names=feature_names, show=True)
si.plot_waterfall(feature_names=feature_names, show=True)
mi.plot_waterfall(feature_names=feature_names, show=True)

3. Network Plot

Displays first- and second-order feature interactions: node size reflects each feature's individual impact, while edge width and color encode the strength of pairwise interactions.

si.plot_network(feature_names=feature_names, show=True)
mi.plot_network(feature_names=feature_names, show=True)

4. SI Graph Plot

Visualizes higher-order interactions as hyper-edges connecting multiple features, offering a comprehensive view of joint influences.

abbrev_feature_names = shapiq.plot.utils.abbreviate_feature_names(feature_names)
sv.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
si.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
mi.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)

5. Bar Plot

Used for global explanations: it summarizes overall feature importance and interaction strength across many instances.

explanations = []
explainer = shapiq.TreeExplainer(model=model, max_order=2, index="k-SII")
for instance_id in tqdm(range(20)):
    x_explain = x_test[instance_id]
    si = explainer.explain(x=x_explain)
    explanations.append(si)
shapiq.plot.bar_plot(explanations, feature_names=feature_names, show=True)

Features like "Distance" and "Horsepower" have the largest average impact, while interactions such as "Horsepower × Weight" show significant combined effects, highlighting non-linear relationships in the model.
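To see the numbers behind the bar plot, the explanations can be aggregated manually. A minimal sketch of the mean absolute first-order attribution per feature, again assuming the dict_values mapping on each InteractionValues object:

import numpy as np
 
# Average absolute single-feature attribution across the explained instances
mean_abs = np.zeros(n_features)
for expl in explanations:
    for coalition, value in expl.dict_values.items():
        if len(coalition) == 1:
            mean_abs[coalition[0]] += abs(value) / len(explanations)
 
for name, score in sorted(zip(feature_names, mean_abs), key=lambda t: -t[1]):
    print(f"{name}: {score:.3f}")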

