<НА ГЛАВНУЮ

Осваиваем SHAP-IQ Визуализации: Раскрываем Внутреннюю Логику Моделей Машинного Обучения

Узнайте, как визуализации SHAP-IQ помогают понять предсказания моделей машинного обучения, разбивая вклад признаков и их взаимодействия с помощью наглядных инструментов.

Исследование Визуализаций SHAP-IQ для Интерпретации Моделей

Визуализации SHAP-IQ предоставляют мощные инструменты для понимания того, как модели машинного обучения приходят к своим предсказаниям, разбивая сложное поведение на понятные компоненты. Они показывают вклад отдельных признаков и их взаимодействия, помогая ясно интерпретировать решения модели.

Настройка Окружения

Установите необходимые зависимости:

!pip install shapiq overrides scikit-learn pandas numpy seaborn

Импортируйте библиотеки и проверьте версию shapiq:

from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from tqdm.asyncio import tqdm
import shapiq
 
print(f"shapiq version: {shapiq.__version__}")

Загрузка и Подготовка Данных

Используем набор данных MPG (мили на галлон) из библиотеки Seaborn с информацией о характеристиках автомобилей.

import seaborn as sns
df = sns.load_dataset("mpg")
df

Обрабатываем данные: удаляем пропуски и кодируем категориальные признаки:

import pandas as pd
from sklearn.preprocessing import LabelEncoder
 
df = df.dropna()
le = LabelEncoder()
df.loc[:, "origin"] = le.fit_transform(df["origin"])
 
for i, label in enumerate(le.classes_):
    print(f"{label}{i}")

Разделение Данных и Обучение Модели

Выделяем признаки и целевую переменную, делим на обучающую и тестовую выборки:

X = df.drop(columns=["mpg", "name"])
y = df["mpg"]
 
feature_names = X.columns.tolist()
x_data, y_data = X.values, y.values
 
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=42)

Обучаем регрессор случайного леса:

model = RandomForestRegressor(random_state=42, max_depth=10, n_estimators=10)
model.fit(x_train, y_train)

Оценка Качества Модели

Вычисляем среднеквадратичную ошибку и коэффициент детерминации:

mse = mean_squared_error(y_test, model.predict(x_test))
r2 = r2_score(y_test, model.predict(x_test))
print(f"Mean Squared Error: {mse:.2f}")
print(f"R2 Score: {r2:.2f}")

Объяснение Отдельного Примера

Рассматриваем, как модель предсказала значение для конкретного примера (instance_id = 7):

instance_id = 7
x_explain = x_test[instance_id]
y_true = y_test[instance_id]
y_pred = model.predict(x_explain.reshape(1, -1))[0]
print(f"Instance {instance_id}, True Value: {y_true}, Predicted Value: {y_pred}")
for i, feature in enumerate(feature_names):
    print(f"{feature}: {x_explain[i]}")

Генерация SHAP-Объяснений

Создаем объяснения для разных порядков взаимодействия (отдельные признаки, парные взаимодействия, полные взаимодействия):

feature_names = list(X.columns)
n_features = len(feature_names)
 
si_order = {}
for order in tqdm([1, 2, n_features]):
    index = "k-SII" if order > 1 else "SV"
    explainer = shapiq.TreeExplainer(model=model, max_order=order, index=index)
    si_order[order] = explainer.explain(x=x_explain)
si_order

Визуализация Результатов

1. Force Chart Показывает, как каждый признак влияет на предсказание относительно базового значения. Красные бары увеличивают предсказание, синие — уменьшают. Также отображаются взаимодействия признаков.

sv = si_order[1]
si = si_order[2]
mi = si_order[n_features]
 
sv.plot_force(feature_names=feature_names, show=True)
si.plot_force(feature_names=feature_names, show=True)
mi.plot_force(feature_names=feature_names, show=True)

2. Waterfall Chart Похож на force chart, но группирует малозначимые признаки в категорию «прочее» для удобства восприятия.

sv.plot_waterfall(feature_names=feature_names, show=True)
si.plot_waterfall(feature_names=feature_names, show=True)
mi.plot_waterfall(feature_names=feature_names, show=True)

3. Network Plot Отображает взаимодействия первого и второго порядка между признаками с размером узлов и толщиной/цветом ребер, отражающими влияние и силу взаимодействия.

si.plot_network(feature_names=feature_names, show=True)
mi.plot_network(feature_names=feature_names, show=True)

4. SI Graph Plot Визуализирует все высокопорядковые взаимодействия в виде гиперребер между несколькими признаками, давая полное представление о совместном влиянии.

abbrev_feature_names = shapiq.plot.utils.abbreviate_feature_names(feature_names)
sv.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
si.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
mi.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)

5. Bar Plot Используется для глобальных объяснений, суммируя среднюю важность признаков и взаимодействий по всем примерам.

explanations = []
explainer = shapiq.TreeExplainer(model=model, max_order=2, index="k-SII")
for instance_id in tqdm(range(20)):
    x_explain = x_test[instance_id]
    si = explainer.explain(x=x_explain)
    explanations.append(si)
shapiq.plot.bar_plot(explanations, feature_names=feature_names, show=True)

Признаки "Distance" и "Horsepower" оказывают наибольшее влияние, а взаимодействия "Horsepower × Weight" и "Distance × Horsepower" показывают значимые совместные эффекты, указывая на наличие нелинейных зависимостей в модели.

Полный код и дополнительные материалы доступны на GitHub. Следите за обновлениями в Twitter, присоединяйтесь к сообществу на Reddit и подписывайтесь на рассылку.

🇬🇧

Switch Language

Read this article in English

Switch to English