解释一个简单的OR函数
本笔记本研究了使用SHAP值解释OR函数的样子。通过这个例子,我们理解了改变背景分布如何影响你从 TreeExplainer 获得的解释。
它基于一个简单的例子,包含两个特征 is_young
和 is_female
,大致灵感来自于泰坦尼克号生存数据集,其中妇女和儿童在疏散时被优先考虑,因此更有可能生存。在这个模拟的例子中,这种效应被推向极端,所有儿童和妇女都生存,而没有成年男性生存。
[1]:
import numpy as np
import pandas as pd
import xgboost
from IPython.display import display
import shap
rng = np.random.default_rng(42)
按照OR函数创建一个数据集
[2]:
N = 40_000
M = 2
# randomly create binary features for `is_young` and `is_female`
X = (rng.standard_normal(size=(N, 2)) > 0) * 1
X = pd.DataFrame(X, columns=["is_young", "is_female"])
# force the first sample to be a young boy
X.loc[0, :] = [1, 0]
display(X.head(3))
# you survive (y=1) only if you are young or female
y = ((X.loc[:, "is_young"] + X.loc[:, "is_female"]).to_numpy() > 0) * 1
is_young | is_female | |
---|---|---|
0 | 1 | 0 |
1 | 1 | 1 |
2 | 0 | 0 |
训练一个XGBoost模型来模仿这个OR函数
[3]:
model = xgboost.XGBRegressor(n_estimators=100, learning_rate=0.1, random_state=3)
model.fit(X, y)
model.predict(X)
[3]:
array([9.9998671e-01, 9.9998671e-01, 1.3295135e-05, ..., 9.9998671e-01,
9.9998671e-01, 9.9998671e-01], dtype=float32)
解释对一个男孩的预测
使用训练集作为背景分布
请注意,在下面的示例解释中,is_young = True
具有正值(这意味着它增加了模型输出,从而增加了生存预测),而 is_female = False
具有负值(这意味着它减少了模型输出)。虽然有人可能会认为 is_female = False
应该没有影响,因为我们已经知道这个人是年轻的,但SHAP值考虑了特征的影响,即使我们不一定知道其他特征,这也是为什么 is_female = False
仍然对预测有负面影响。
[4]:
explainer = shap.TreeExplainer(model, X, feature_perturbation="interventional")
explanation = explainer(X.loc[[0], :])
# for the young boy:
expected_value = explanation.base_values[0]
shap_values = explanation.values[0]
print(f"explainer.expected_value: {expected_value:.4f}")
print(f"SHAP values for (is_young = True, is_female = False): {shap_values.round(4)}")
print("model output:", (expected_value + shap_values.sum()).round(4))
explainer.expected_value: 0.7600
SHAP values for (is_young = True, is_female = False): [ 0.385 -0.145]
model output: 1.0
同样的信息,但以瀑布图的形式可视化:
[5]:
# waterfall plot for the young boy (background distribution => training set)
shap.plots.waterfall(explanation[0])
仅使用负面示例作为背景分布
这个第二个解释示例的目的是演示如何使用不同的背景分布可以改变输入特征之间的信用分配。这是因为我们现在将一个特征的重要性与一个已故的人(成年男性)进行比较。年轻男孩与已故的人唯一不同的是男孩年轻,因此所有的信用都归于 is_young = True
特征。
这突显出,当使用一个定义良好的背景组时,解释通常会更清晰。在这种情况下,它将解释从“这个样本与典型的有何不同”转变为“这个样本与那些死亡的人有何不同”(换句话说,为什么你活下来了?)。
[6]:
explainer = shap.TreeExplainer(
model,
X.loc[y == 0, :], # background distribution => non-survival
feature_perturbation="interventional",
)
explanation = explainer(X.loc[[0], :])
# for the young boy:
expected_value = explanation.base_values[0]
shap_values = explanation.values[0]
print(f"explainer.expected_value: {expected_value:.4f}")
print(f"SHAP values for (is_young = True, is_female = False): {shap_values.round(4)}")
print("model output:", (expected_value + shap_values.sum()).round(4))
explainer.expected_value: 0.0000
SHAP values for (is_young = True, is_female = False): [1. 0.]
model output: 1.0
[7]:
# waterfall plot for the young boy (background distribution => non-survival)
shap.plots.waterfall(explanation[0])
仅使用背景分布的正例
我们也可以仅使用正例作为我们的背景分布,并且由于模型预期输出(在我们的背景分布下)与当前小男孩的输出之间的差异为零,SHAP值的总和也将为零。
[8]:
explainer = shap.TreeExplainer(
model,
X.loc[y == 1, :], # background distribution => survival
feature_perturbation="interventional",
)
explanation = explainer(X.loc[[0], :])
# for the young boy:
expected_value = explanation.base_values[0]
shap_values = explanation.values[0]
print(f"explainer.expected_value: {expected_value:.4f}")
print(f"SHAP values for (is_young = True, is_female = False): {shap_values.round(4)}")
print("model output:", (expected_value + shap_values.sum()).round(4))
explainer.expected_value: 1.0000
SHAP values for (is_young = True, is_female = False): [ 0.14 -0.14]
model output: 1.0
[9]:
# waterfall plot for the young boy (background distribution => survival)
shap.plots.waterfall(explanation[0])
使用年轻女性作为背景分布
如果我们把样本与年轻女性进行比较,那么除了成年男性(data=(0, 0)
)之外,这两个特征都不重要,在这种情况下,这两个特征都被赋予了相同的死亡信用(正如人们可能直观地预期的那样)。
[10]:
explainer = shap.TreeExplainer(
model,
np.ones((1, M)), # background distribution => all young women
feature_perturbation="interventional",
)
explanation = explainer(X.head(3))
print("Feature data:")
display(explanation.data)
print()
print("SHAP values:")
display(explanation.values.round(4))
Feature data:
array([[1, 0],
[1, 1],
[0, 0]])
SHAP values:
array([[ 0. , -0. ],
[ 0. , 0. ],
[-0.5, -0.5]])