Basic SHAP Interaction Value Example in XGBoost
This notebook shows how the SHAP interaction values for a very simple function are computed. We start with a simple linear function, then add an interaction term and see how it changes the SHAP values and the SHAP interaction values.
[1]:
import numpy as np
import xgboost
from sklearn.linear_model import LinearRegression
import shap
Explain a linear function with no interactions
Simulate some binary data and a linear outcome with no interaction terms. Note that we make the features in X perfectly independent of each other, to make it easy to solve for the exact SHAP values.
[2]:
N = 2_000
X = np.zeros((N, 5))
X[:1_000, 0] = 1
X[:500, 1] = 1
X[1_000:1_500, 1] = 1
X[:250, 2] = 1
X[500:750, 2] = 1
X[1_000:1_250, 2] = 1
X[1_500:1_750, 2] = 1
# mean-center the data
X[:, 0:3] -= 0.5
y = 2 * X[:, 0] - 3 * X[:, 1]
We see that the variables are indeed independent:
[3]:
np.cov(X.T)
[3]:
array([[0.25012506, 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.25012506, 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.25012506, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ]])
And mean-centered:
[4]:
X.mean(axis=0)
[4]:
array([0., 0., 0., 0., 0.])
[5]:
# train a model with a single tree
Xd = xgboost.DMatrix(X, label=y)
model = xgboost.train({"eta": 1, "max_depth": 3, "base_score": 0, "lambda": 0}, Xd, 1)
print("Model error =", np.linalg.norm(y - model.predict(Xd)))
print(model.get_dump(with_stats=True)[0])
Model error = 0.0
0:[f1<0] yes=1,no=2,missing=1,gain=4500,cover=2000
	1:[f0<0] yes=3,no=4,missing=3,gain=1000,cover=1000
		3:leaf=0.5,cover=500
		4:leaf=2.5,cover=500
	2:[f0<0] yes=5,no=6,missing=5,gain=1000,cover=1000
		5:leaf=-2.5,cover=500
		6:leaf=-0.5,cover=500
SHAP values
[6]:
pred = model.predict(Xd, output_margin=True)
explainer = shap.TreeExplainer(model)
explanation = explainer(Xd)
shap_values = explanation.values
# make sure the SHAP values add up to marginal predictions
np.abs(shap_values.sum(axis=1) + explanation.base_values - pred).max()
[6]:
0.0
If we build a beeswarm plot, we see that only features 0 and 1 have any effect on the output, and that each has only two possible effect magnitudes (1/-1 and 1.5/-1.5, respectively).
[7]:
shap.plots.beeswarm(explanation)
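The rendered plot is omitted here, but we can confirm the same fact numerically by listing the distinct SHAP values each feature takes (a minimal sketch reusing the shap_values array from above; features 0 and 1 should show ±1 and ±1.5, and all other features exactly 0):

# list the distinct SHAP values taken by each feature (rounded to suppress float noise)
for j in range(X.shape[1]):
    print(f"feature {j}:", np.unique(shap_values[:, j].round(6)))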
Train a linear model
[8]:
lr = LinearRegression()
lr.fit(X, y)
lr_pred = lr.predict(X)
lr.coef_.round(2)
[8]:
array([ 2., -3., -0., 0., 0.])
Make sure the computed SHAP values match the true SHAP values (for any linear regression we can compute the exact SHAP values directly):
[9]:
main_effect_shap_values = lr.coef_ * (X - X.mean(0))
np.linalg.norm(shap_values - main_effect_shap_values)
[9]:
1.6542433490447965e-13
SHAP interaction values
Note that when no interactions are present, the SHAP interaction values are just a diagonal matrix with the SHAP values on the diagonal.
[10]:
shap_interaction_values = explainer.shap_interaction_values(Xd)
shap_interaction_values[0]
[10]:
array([[ 1. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. , -1.5,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ],
       [ 0. ,  0. ,  0. ,  0. ,  0. ]], dtype=float32)
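The cell above only inspects the first sample. As a sanity check we can zero out the diagonal and verify that nothing remains anywhere in the array (a small sketch, reusing shap_interaction_values):

# zero the main-effect diagonal and confirm every off-diagonal entry is (near) zero
off_diag = shap_interaction_values.copy()
idx = np.arange(X.shape[1])
off_diag[:, idx, idx] = 0
print(np.abs(off_diag).max())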
Let's make sure the SHAP interaction values sum to the marginal predictions:
[11]:
np.abs(shap_interaction_values.sum((1, 2)) + explainer.expected_value - pred).max()
[11]:
0.0
And make sure the main effects from the SHAP interaction values match those from the linear model:
[12]:
total = 0
for i in range(N):
    for j in range(5):
        total += np.abs(
            shap_interaction_values[i, j, j] - main_effect_shap_values[i, j]
        )
total
[12]:
1.0533118387982904e-11
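The same comparison can be written without explicit loops. This vectorized restatement of the cell above pulls the diagonal out of the interaction tensor in one step:

# extract the per-sample diagonal (the main effects) and compare in one shot
idx = np.arange(X.shape[1])
diag = shap_interaction_values[:, idx, idx]
print(np.abs(diag - main_effect_shap_values).sum())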
Explain a linear model with one interaction
Simulate some binary data and a linear outcome with an interaction term. Note that we make the features in X perfectly independent of each other, to make it easy to solve for the exact SHAP values.
[13]:
N = 2_000
X = np.zeros((N, 5))
X[:1_000, 0] = 1
X[:500, 1] = 1
X[1_000:1_500, 1] = 1
X[:250, 2] = 1
X[500:750, 2] = 1
X[1_000:1_250, 2] = 1
X[1_500:1_750, 2] = 1
X[:125, 3] = 1
X[250:375, 3] = 1
X[500:625, 3] = 1
X[750:875, 3] = 1
X[1_000:1_125, 3] = 1
X[1_250:1_375, 3] = 1
X[1_500:1_625, 3] = 1
X[1_750:1_875, 3] = 1
# we can't exactly mean center the data or XGBoost has trouble finding the splits
X[:, :4] -= 0.4999
# interaction of features is implemented as the multiplication of the features. Note that any other function of the
# features would also work, but is harder to interpret (e.g. sin(x1*x2)).
y = 2 * X[:, 0] - 3 * X[:, 1] + 2 * X[:, 1] * X[:, 2]
[14]:
X.mean(axis=0)
[14]:
array([1.e-04, 1.e-04, 1.e-04, 1.e-04, 0.e+00])
[15]:
# train a model with a single tree
Xd = xgboost.DMatrix(X, label=y)
model = xgboost.train({"eta": 1, "max_depth": 4, "base_score": 0, "lambda": 0}, Xd, 1)
print("Model error =", np.linalg.norm(y - model.predict(Xd)))
print(model.get_dump(with_stats=True)[0])
Model error = 1.7365037830677591e-06
0:[f1<0.000100001693] yes=1,no=2,missing=1,gain=4499.3999,cover=2000
	1:[f0<0.000100001693] yes=3,no=4,missing=3,gain=1000.00024,cover=1000
		3:[f2<0.000100001693] yes=7,no=8,missing=7,gain=124.950005,cover=500
			7:leaf=0.99970001,cover=250
			8:leaf=-9.99800031e-05,cover=250
		4:[f2<0.000100001693] yes=9,no=10,missing=9,gain=124.950195,cover=500
			9:leaf=2.99970007,cover=250
			10:leaf=1.99989998,cover=250
	2:[f0<0.000100001693] yes=5,no=6,missing=5,gain=999.999756,cover=1000
		5:[f2<0.000100001693] yes=11,no=12,missing=11,gain=125.050049,cover=500
			11:leaf=-3.0000999,cover=250
			12:leaf=-1.99989998,cover=250
		6:[f2<0.000100001693] yes=13,no=14,missing=13,gain=125.050018,cover=500
			13:leaf=-1.00010002,cover=250
			14:leaf=0.000100019999,cover=250
SHAP values
[16]:
pred = model.predict(Xd, output_margin=True)
explainer = shap.TreeExplainer(model)
explanation = explainer(Xd)
shap_values = explanation.values
# make sure the SHAP values add up to marginal predictions
np.abs(shap_values.sum(axis=1) + explanation.base_values - pred).max()
[16]:
4.7683716e-07
If we build a beeswarm plot, we see that now only features 3 and 4 don't matter, and that feature 1 can have four possible effect sizes due to the interaction.
[17]:
shap.plots.beeswarm(explanation)
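As before the plot itself is omitted here; the four effect sizes for feature 1 can be verified directly (a quick sketch; we round because the data is only approximately mean-centered):

# feature 1 now takes four distinct SHAP values because of its interaction with feature 2
print(np.unique(shap_values[:, 1].round(2)))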
Train a linear model
[18]:
lr = LinearRegression()
lr.fit(X, y)
lr_pred = lr.predict(X)
lr.coef_.round(2)
[18]:
array([ 2., -3., 0., -0., 0.])
Note that the SHAP values no longer match the main effects, because they now also include the interaction effects.
[19]:
main_effect_shap_values = lr.coef_ * (X - X.mean(0))
np.linalg.norm(shap_values - main_effect_shap_values)
[19]:
15.811387829626835
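The mismatch is confined to the features involved in the interaction. A quick per-feature check (a sketch reusing the arrays above) should show a nonzero difference only for features 1 and 2:

# largest absolute gap per feature between SHAP values and the linear main effects
diff = np.abs(shap_values - main_effect_shap_values).max(axis=0)
print(diff.round(2))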
SHAP interaction values
The SHAP interaction contributions show up on the off-diagonal:
[20]:
shap_interaction_values = explainer.shap_interaction_values(Xd)
shap_interaction_values[0].round(2)
[20]:
array([[ 1.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  , -1.5 ,  0.25,  0.  ,  0.  ],
       [ 0.  ,  0.25,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ]], dtype=float32)
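Note that the interaction effect is split evenly between the (1, 2) and (2, 1) entries. This symmetry holds for the whole tensor, which we can check directly (a minimal sketch):

# SHAP interaction values are symmetric in their last two axes
sym_gap = np.abs(
    shap_interaction_values - shap_interaction_values.transpose(0, 2, 1)
).max()
print(sym_gap)  # should be (near) zero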
Ensure the SHAP interaction values sum to the marginal predictions:
[21]:
np.abs(shap_interaction_values.sum((1, 2)) + explainer.expected_value - pred).max()
[21]:
4.7683716e-07
While the main effects no longer match the SHAP values when interactions are present, they do match the main effects of the linear model, which appear on the diagonal of the SHAP interaction value matrix:
[22]:
total = 0
for i in range(N):
    for j in range(5):
        total += np.abs(
            shap_interaction_values[i, j, j] - main_effect_shap_values[i, j]
        )
total
[22]:
0.0005347490392160024
If we build a dependence plot for feature 0, we see that it only takes two values, and that these values are entirely determined by the value of the feature (the value of feature 0 completely determines its effect, because it has no interactions with other features).
[23]:
shap.dependence_plot(0, shap_values, X)
In contrast, if we build a dependence plot for feature 2, we see that it takes four possible values, and that they are not entirely determined by the value of feature 2. Instead, they also depend on the value of feature 1, the feature it interacts with. This vertical spread in a dependence plot represents the effects of non-linear interactions.
[24]:
shap.dependence_plot(2, shap_values, X)
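To make the source of that vertical spread explicit, we can color the feature 2 dependence plot by feature 1 using the interaction_index argument (by default shap.dependence_plot picks a coloring feature automatically; here we set it by hand):

# color each point by the value of feature 1, the feature that drives the spread
shap.dependence_plot(2, shap_values, X, interaction_index=1)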