Diabetes regression with scikit-learn

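This example uses the model-agnostic KernelExplainer, together with the tree-specific TreeExplainer, to explain several different regression models trained on a small diabetes dataset. The notebook is meant to show how KernelExplainer can be used with a variety of models.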
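KernelExplainer estimates Shapley values. A feature's Shapley value is its average marginal contribution to the prediction over all coalitions of the other features, with "missing" features filled in from background data. For a handful of features this quantity can be computed exactly by enumerating every coalition, as in the sketch below.

The sketch rests on a few assumptions:

- The helper `exact_shapley` and the toy `model` are hypothetical.
- Missing features are replaced by each background row in turn and the model output is averaged, which amounts to assuming the features are independent.

The enumeration costs 2^M model evaluations for M features. KernelExplainer instead approximates the same values with a weighted regression over sampled coalitions.

```python
import itertools
import math

import numpy as np


def exact_shapley(f, x, background):
    """Exact Shapley values of f at x by enumerating every feature coalition.

    Hypothetical helper (not part of shap). Features outside a coalition are taken
    from each background row, and the model output is averaged over the background.
    """
    m = len(x)

    def value(coalition):
        # Start from the background rows and overwrite the "present" features with x.
        data = np.array(background, dtype=float)
        idx = list(coalition)
        data[:, idx] = x[idx]
        return f(data).mean()

    phi = np.zeros(m)
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight |S|! (M - |S| - 1)! / M! for coalitions of this size.
            weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
            for s in itertools.combinations(others, size):
                phi[i] += weight * (value(s + (i,)) - value(s))
    return phi


def model(X):
    # Toy model: an additive term in feature 0 plus an interaction of features 1 and 2.
    return 2.0 * X[:, 0] + X[:, 1] * X[:, 2]


background = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
x = np.array([1.0, 2.0, 3.0])
phi = exact_shapley(model, x, background)
base = model(background).mean()  # the "expected value" of the explanation
print(phi, base + phi.sum(), model(x[None, :])[0])
```

The printed check is the local-accuracy property. The base value (the mean model output over the background) plus the SHAP values recovers the model's prediction, and this is what a force plot draws.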

Load the data

[1]:
import time

import numpy as np
from sklearn.model_selection import train_test_split

import shap

X, y = shap.datasets.diabetes()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Rather than use the whole training set to estimate expected values, summarize it
# with 10 weighted k-means centroids, each weighted by the number of points it represents.
X_train_summary = shap.kmeans(X_train, 10)


def print_accuracy(f):
    """Print the root mean squared error of prediction function f on the test set."""
    print(
        f"Root mean squared test error = {np.sqrt(np.mean((f(X_test) - y_test) ** 2))}"
    )
    time.sleep(0.5)  # give the print time to appear before any progress bars


shap.initjs()
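KernelExplainer averages every model evaluation over the background rows, so its cost grows with the size of the background set. That is why the cell above passes a 10-centroid summary rather than all of `X_train`.

Below is a minimal sketch of such a weighted summary using plain Lloyd's k-means. `weighted_kmeans_summary` is a hypothetical helper, not shap.kmeans itself, and the two-blob data is made up.

```python
import numpy as np


def weighted_kmeans_summary(X, k, n_iter=50, seed=0):
    """Summarize X by k centroids plus the fraction of rows each centroid represents.

    Hypothetical helper using Lloyd's algorithm; not shap's implementation.
    """
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize centroids with k distinct rows (fancy indexing makes a copy).
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assign every row to its nearest centroid (squared Euclidean distance).
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its rows; leave empty clusters in place.
        for j in range(k):
            members = X[labels == j]
            if len(members) > 0:
                centroids[j] = members.mean(axis=0)
    # Final assignment determines how many rows each centroid stands for.
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    weights = np.bincount(labels, minlength=k) / len(X)
    return centroids, weights


rng = np.random.default_rng(1)
X_demo = np.vstack([rng.normal(0.0, 0.1, size=(30, 2)), rng.normal(5.0, 0.1, size=(10, 2))])
centroids, weights = weighted_kmeans_summary(X_demo, k=2)
print(centroids, weights)
```

An explainer would then evaluate the model on the k centroids instead of every training row, weighting each centroid's output by its share of the data.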

Linear regression

[2]:
from sklearn import linear_model

lin_regr = linear_model.LinearRegression()
lin_regr.fit(X_train, y_train)

print_accuracy(lin_regr.predict)
Root mean squared test error = 58.51766133582009
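For a linear model with (assumed) independent features, SHAP values have a closed form: phi_j = w_j * (x_j - E[x_j]). The base value is then the prediction at the feature means. We would therefore expect KernelExplainer to return approximately these values for `lin_regr`, with E[x_j] coming from the weighted background summary.

The sketch below checks the closed form on synthetic data. The data, coefficients and names are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 3))
y_demo = X_demo @ np.array([3.0, -2.0, 0.5]) + 1.0 + rng.normal(scale=0.1, size=200)

# Ordinary least squares with an explicit intercept column.
A = np.column_stack([np.ones(len(X_demo)), X_demo])
coef, *_ = np.linalg.lstsq(A, y_demo, rcond=None)
intercept, w = coef[0], coef[1:]


def predict(X):
    return intercept + X @ w


# Closed-form SHAP values: each feature's coefficient times its deviation from the mean.
mean = X_demo.mean(axis=0)
phi = (X_demo - mean) * w
base_value = predict(mean[None, :])[0]

# Local accuracy: base value plus row sums reproduces every prediction.
print(np.abs(base_value + phi.sum(axis=1) - predict(X_demo)).max())
```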

Explain a single prediction from the test set

[3]:
ex = shap.KernelExplainer(lin_regr.predict, X_train_summary)
shap_values = ex.shap_values(X_test.iloc[0, :])
shap.force_plot(ex.expected_value, shap_values, X_test.iloc[0, :])
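A force plot starts from the explainer's `expected_value` and pushes it up or down by each feature's SHAP value until it reaches the model's prediction. The sketch below is a plain-text stand-in for that picture, listing the largest effects first.

`force_plot_text` is a hypothetical helper, not a shap function. The numbers in the usage line are made up, not output from this notebook.

```python
import numpy as np


def force_plot_text(base_value, shap_row, feature_names):
    """Print features ordered by |SHAP value| and return that order (hypothetical helper)."""
    order = np.argsort(-np.abs(shap_row))  # largest effects first
    print(f"base value: {base_value:.2f}")
    for j in order:
        if shap_row[j] > 0:
            direction = "pushes up"
        elif shap_row[j] < 0:
            direction = "pushes down"
        else:
            direction = "no effect"
        print(f"  {feature_names[j]:>6}: {shap_row[j]:+8.2f} ({direction})")
    # Local accuracy: base value + sum of SHAP values = model output.
    print(f"prediction: {base_value + shap_row.sum():.2f}")
    return [feature_names[j] for j in order]


# Illustrative, made-up explanation of one row.
order = force_plot_text(150.0, np.array([20.0, -5.0, 12.0]), ["bmi", "bp", "s5"])
```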

Explain all the predictions in the test set

[4]:
shap_values = ex.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
100%|██████████| 89/89 [00:21<00:00,  4.16it/s]
[Figure: SHAP summary plot for the linear regression model]
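As we understand it, the summary plot orders features by their overall magnitude of effect: the mean (equivalently, summed) absolute SHAP value over all explained rows. That ranking can be computed directly from the SHAP matrix.

`mean_abs_importance` is a hypothetical helper, and the 2x2 matrix is made up.

```python
import numpy as np


def mean_abs_importance(shap_values, feature_names):
    """Rank features by mean |SHAP value| across rows (hypothetical helper)."""
    importance = np.abs(shap_values).mean(axis=0)
    order = np.argsort(-importance)  # most important first
    return [(feature_names[j], float(importance[j])) for j in order]


# Rows are explained samples, columns are features.
ranking = mean_abs_importance(np.array([[1.0, -3.0], [-1.0, 1.0]]), ["bmi", "s5"])
print(ranking)
```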
[5]:
shap.dependence_plot("bmi", shap_values, X_test)
[Figure: SHAP dependence plot of bmi for the linear regression model]
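A dependence plot scatters one feature's values against that feature's SHAP values. For the linear model the points should fall on a straight line, while for the tree models they form steps.

A numeric stand-in is to average the SHAP values within quantile bins of the feature. `binned_dependence` is a hypothetical helper, and the linear toy data is made up.

```python
import numpy as np


def binned_dependence(feature_values, shap_column, n_bins=5):
    """Mean feature value and mean SHAP value within quantile bins (hypothetical helper)."""
    edges = np.quantile(feature_values, np.linspace(0, 1, n_bins + 1))
    # Using only the interior edges maps every value to a bin index in 0..n_bins-1.
    bins = np.digitize(feature_values, edges[1:-1])
    return [
        (float(feature_values[bins == b].mean()), float(shap_column[bins == b].mean()))
        for b in range(n_bins)
        if np.any(bins == b)
    ]


x_feature = np.linspace(-1.0, 1.0, 100)
shap_column = 3.0 * x_feature  # what a linear model with coefficient 3 would give
for mean_x, mean_phi in binned_dependence(x_feature, shap_column):
    print(f"{mean_x:+.2f} -> {mean_phi:+.2f}")
```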
[6]:
shap.force_plot(ex.expected_value, shap_values, X_test)

Decision tree regressor

[7]:
from sklearn import tree

dtree = tree.DecisionTreeRegressor(min_samples_split=20)
dtree.fit(X_train, y_train)
print_accuracy(dtree.predict)

# explain all the predictions in the test set
ex = shap.TreeExplainer(dtree)
shap_values = ex.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
Root mean squared test error = 71.98699151013147
[Figure: SHAP summary plot for the decision tree regressor]
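TreeExplainer does not need a background set here. As we understand it, when it is given no data it uses the "tree_path_dependent" definition of a missing feature. At a split on a known feature it follows the branch chosen by x. At a split on an unknown feature it averages both branches, weighted by how many training samples went each way.

shap computes this in polynomial time. The sketch below computes the same definition by brute force over coalitions on a tiny hand-built tree. The tree arrays only mimic the names of sklearn's `tree_` attributes, and the helpers are hypothetical.

```python
import itertools
import math

import numpy as np

# Hypothetical hand-built tree: node 0 splits on feature 0, node 2 on feature 1.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -1, 1, -1, -1]
threshold = [0.5, 0.0, 0.5, 0.0, 0.0]
value = [0.0, 1.0, 0.0, 2.0, 4.0]  # leaf predictions (unused at internal nodes)
cover = [10, 4, 6, 3, 3]  # training samples reaching each node


def cond_expectation(x, coalition, node=0):
    """Path-dependent E[f(x) | x_S]: follow known features, cover-weight unknown ones."""
    if children_left[node] == -1:
        return value[node]
    left, right = children_left[node], children_right[node]
    if feature[node] in coalition:
        child = left if x[feature[node]] <= threshold[node] else right
        return cond_expectation(x, coalition, child)
    return (cover[left] * cond_expectation(x, coalition, left)
            + cover[right] * cond_expectation(x, coalition, right)) / cover[node]


def brute_force_tree_shap(x, m):
    """Exact Shapley values of the path-dependent value function (exponential in m)."""
    phi = np.zeros(m)
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            weight = math.factorial(size) * math.factorial(m - size - 1) / math.factorial(m)
            for s in itertools.combinations(others, size):
                phi[i] += weight * (cond_expectation(x, set(s) | {i})
                                    - cond_expectation(x, set(s)))
    return phi


x = np.array([1.0, 1.0])
phi = brute_force_tree_shap(x, 2)
expected_value = cond_expectation(x, set())  # cover-weighted mean of the leaves
prediction = cond_expectation(x, {0, 1})  # the tree's actual output at x
print(phi, expected_value, prediction)
```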
[8]:
shap.dependence_plot("bmi", shap_values, X_test)
[Figure: SHAP dependence plot of bmi for the decision tree regressor]
[9]:
shap.force_plot(ex.expected_value, shap_values, X_test)

Random forest

The random forest is explained with the fast TreeExplainer implementation rather than KernelExplainer.
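A forest is cheap to explain tree by tree because Shapley values are linear in the model. A random forest regressor averages its trees, so the forest's SHAP values equal the average of the per-tree SHAP values.

The sketch checks this on two hypothetical step-function "trees" with two features. To keep it short, it fills in missing features from a single reference point, which is a simplification and not TreeExplainer's path-dependent rule.

```python
import numpy as np


def shapley_two_features(f, x, b):
    """Exact Shapley values for a 2-feature model; missing features take values from b."""
    v_empty = f(b[0], b[1])
    v_0 = f(x[0], b[1])
    v_1 = f(b[0], x[1])
    v_01 = f(x[0], x[1])
    phi0 = 0.5 * ((v_0 - v_empty) + (v_01 - v_1))
    phi1 = 0.5 * ((v_1 - v_empty) + (v_01 - v_0))
    return np.array([phi0, phi1])


def tree_a(x0, x1):
    # Hypothetical depth-2 tree.
    return 10.0 if x0 > 0.5 else (3.0 if x1 > 0.5 else 1.0)


def tree_b(x0, x1):
    # Another hypothetical depth-2 tree.
    return 6.0 if x1 > 0.5 else (4.0 if x0 > 0.5 else 0.0)


def forest(x0, x1):
    # A random forest regressor averages its trees' predictions.
    return 0.5 * (tree_a(x0, x1) + tree_b(x0, x1))


x = np.array([1.0, 1.0])
b = np.array([0.0, 0.0])
phi_forest = shapley_two_features(forest, x, b)
phi_avg = 0.5 * (shapley_two_features(tree_a, x, b) + shapley_two_features(tree_b, x, b))
print(phi_forest, phi_avg)
```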

[10]:
from sklearn.ensemble import RandomForestRegressor

rforest = RandomForestRegressor(
    n_estimators=1000, max_depth=None, min_samples_split=2, random_state=0
)
rforest.fit(X_train, y_train)
print_accuracy(rforest.predict)

# explain all the predictions in the test set
explainer = shap.TreeExplainer(rforest)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
Root mean squared test error = 61.24795842972228
[Figure: SHAP summary plot for the random forest]
[11]:
shap.dependence_plot("bmi", shap_values, X_test)
[Figure: SHAP dependence plot of bmi for the random forest]
[12]:
shap.force_plot(explainer.expected_value, shap_values, X_test)

Neural network

[13]:
from sklearn.neural_network import MLPRegressor

nn = MLPRegressor(solver="lbfgs", alpha=1e-1, hidden_layer_sizes=(5, 2), random_state=0)
nn.fit(X_train, y_train)
print_accuracy(nn.predict)

# explain all the predictions in the test set
explainer = shap.KernelExplainer(nn.predict, X_train_summary)
shap_values = explainer.shap_values(X_test)
shap.summary_plot(shap_values, X_test)
Root mean squared test error = 58.517105754085364
100%|██████████| 89/89 [00:19<00:00,  4.65it/s]
[Figure: SHAP summary plot for the neural network]
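Since the MLP is a black box to shap, it is explained with KernelExplainer. Kernel SHAP fits a weighted linear regression from coalition indicators z to model outputs. Each coalition with |z| present features out of M gets the Shapley kernel weight (M - 1) / (C(M, |z|) * |z| * (M - |z|)). When every coalition is enumerated, the coefficients are the Shapley values.

The sketch below is a minimal version under stated assumptions:

- `kernel_shap` and the toy `model` are hypothetical.
- Missing features come from a single reference row instead of a weighted background set.
- The infinite kernel weights on the empty and full coalitions are replaced by a large finite weight.

```python
import itertools
import math

import numpy as np


def model(X):
    # Hypothetical nonlinear model: an interaction of features 0 and 1, plus feature 2.
    return X[:, 0] * X[:, 1] + X[:, 2]


def kernel_shap(f, x, reference, big_weight=1e6):
    """Kernel SHAP by full coalition enumeration and weighted least squares."""
    m = len(x)
    rows, targets, weights = [], [], []
    for size in range(m + 1):
        for coalition in itertools.combinations(range(m), size):
            z = np.zeros(m)
            z[list(coalition)] = 1.0
            # Features with z == 1 come from x, the rest from the reference row.
            masked = np.where(z == 1.0, x, reference)
            targets.append(f(masked[None, :])[0])
            rows.append(np.concatenate(([1.0], z)))  # leading 1 fits the base value
            if size == 0 or size == m:
                # Large finite weight approximates the exact endpoint constraints.
                weights.append(big_weight)
            else:
                weights.append((m - 1) / (math.comb(m, size) * size * (m - size)))
    A = np.array(rows)
    y = np.array(targets)
    sw = np.sqrt(np.array(weights))
    # Weighted least squares via row scaling by sqrt(weight).
    coef, *_ = np.linalg.lstsq(A * sw[:, None], y * sw, rcond=None)
    return coef[0], coef[1:]  # (base value, SHAP values)


x = np.array([2.0, 3.0, 1.0])
reference = np.zeros(3)
base_value, phi = kernel_shap(model, x, reference)
print(base_value, phi)
```

With 3 features all 8 coalitions fit easily. The 10 diabetes features give 1024 coalitions, and larger inputs quickly make full enumeration impractical, which is why KernelExplainer samples coalitions.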
[14]:
shap.dependence_plot("bmi", shap_values, X_test)
[Figure: SHAP dependence plot of bmi for the neural network]
[15]:
shap.force_plot(explainer.expected_value, shap_values, X_test)