Осваиваем SHAP-IQ Визуализации: Раскрываем Внутреннюю Логику Моделей Машинного Обучения
Узнайте, как визуализации SHAP-IQ помогают понять предсказания моделей машинного обучения, разбивая вклад признаков и их взаимодействия с помощью наглядных инструментов.
Исследование Визуализаций SHAP-IQ для Интерпретации Моделей
Визуализации SHAP-IQ предоставляют мощные инструменты для понимания того, как модели машинного обучения приходят к своим предсказаниям, разбивая сложное поведение на понятные компоненты. Они показывают вклад отдельных признаков и их взаимодействия, помогая ясно интерпретировать решения модели.
Настройка Окружения
Установите необходимые зависимости:
!pip install shapiq overrides scikit-learn pandas numpy seabornИмпортируйте библиотеки и проверьте версию shapiq:
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from tqdm.asyncio import tqdm
import shapiq
print(f"shapiq version: {shapiq.__version__}")Загрузка и Подготовка Данных
Используем набор данных MPG (мили на галлон) из библиотеки Seaborn с информацией о характеристиках автомобилей.
import seaborn as sns
df = sns.load_dataset("mpg")
dfОбрабатываем данные: удаляем пропуски и кодируем категориальные признаки:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
df = df.dropna()
le = LabelEncoder()
df.loc[:, "origin"] = le.fit_transform(df["origin"])
for i, label in enumerate(le.classes_):
print(f"{label} → {i}")Разделение Данных и Обучение Модели
Выделяем признаки и целевую переменную, делим на обучающую и тестовую выборки:
X = df.drop(columns=["mpg", "name"])
y = df["mpg"]
feature_names = X.columns.tolist()
x_data, y_data = X.values, y.values
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=42)Обучаем регрессор случайного леса:
model = RandomForestRegressor(random_state=42, max_depth=10, n_estimators=10)
model.fit(x_train, y_train)Оценка Качества Модели
Вычисляем среднеквадратичную ошибку и коэффициент детерминации:
mse = mean_squared_error(y_test, model.predict(x_test))
r2 = r2_score(y_test, model.predict(x_test))
print(f"Mean Squared Error: {mse:.2f}")
print(f"R2 Score: {r2:.2f}")Объяснение Отдельного Примера
Рассматриваем, как модель предсказала значение для конкретного примера (instance_id = 7):
instance_id = 7
x_explain = x_test[instance_id]
y_true = y_test[instance_id]
y_pred = model.predict(x_explain.reshape(1, -1))[0]
print(f"Instance {instance_id}, True Value: {y_true}, Predicted Value: {y_pred}")
for i, feature in enumerate(feature_names):
print(f"{feature}: {x_explain[i]}")Генерация SHAP-Объяснений
Создаем объяснения для разных порядков взаимодействия (отдельные признаки, парные взаимодействия, полные взаимодействия):
feature_names = list(X.columns)
n_features = len(feature_names)
si_order = {}
for order in tqdm([1, 2, n_features]):
index = "k-SII" if order > 1 else "SV"
explainer = shapiq.TreeExplainer(model=model, max_order=order, index=index)
si_order[order] = explainer.explain(x=x_explain)
si_orderВизуализация Результатов
1. Force Chart Показывает, как каждый признак влияет на предсказание относительно базового значения. Красные бары увеличивают предсказание, синие — уменьшают. Также отображаются взаимодействия признаков.
sv = si_order[1]
si = si_order[2]
mi = si_order[n_features]
sv.plot_force(feature_names=feature_names, show=True)
si.plot_force(feature_names=feature_names, show=True)
mi.plot_force(feature_names=feature_names, show=True)2. Waterfall Chart Похож на force chart, но группирует малозначимые признаки в категорию «прочее» для удобства восприятия.
sv.plot_waterfall(feature_names=feature_names, show=True)
si.plot_waterfall(feature_names=feature_names, show=True)
mi.plot_waterfall(feature_names=feature_names, show=True)3. Network Plot Отображает взаимодействия первого и второго порядка между признаками с размером узлов и толщиной/цветом ребер, отражающими влияние и силу взаимодействия.
si.plot_network(feature_names=feature_names, show=True)
mi.plot_network(feature_names=feature_names, show=True)4. SI Graph Plot Визуализирует все высокопорядковые взаимодействия в виде гиперребер между несколькими признаками, давая полное представление о совместном влиянии.
abbrev_feature_names = shapiq.plot.utils.abbreviate_feature_names(feature_names)
sv.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
si.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)
mi.plot_si_graph(feature_names=abbrev_feature_names, show=True, size_factor=2.5, node_size_scaling=1.5, plot_original_nodes=True)5. Bar Plot Используется для глобальных объяснений, суммируя среднюю важность признаков и взаимодействий по всем примерам.
explanations = []
explainer = shapiq.TreeExplainer(model=model, max_order=2, index="k-SII")
for instance_id in tqdm(range(20)):
x_explain = x_test[instance_id]
si = explainer.explain(x=x_explain)
explanations.append(si)
shapiq.plot.bar_plot(explanations, feature_names=feature_names, show=True)Признаки "Distance" и "Horsepower" оказывают наибольшее влияние, а взаимодействия "Horsepower × Weight" и "Distance × Horsepower" показывают значимые совместные эффекты, указывая на наличие нелинейных зависимостей в модели.
Полный код и дополнительные материалы доступны на GitHub. Следите за обновлениями в Twitter, присоединяйтесь к сообществу на Reddit и подписывайтесь на рассылку.
Switch Language
Read this article in English