beeswarm
图
本笔记本旨在演示(并记录)如何使用 shap.plots.beeswarm
函数。它使用了一个在经典的 UCI 成人收入数据集上训练的 XGBoost 模型(这是一个分类任务,预测人们在 1990 年代的收入是否超过 $50k)。
[1]:
import xgboost
import shap
# train XGBoost model
X, y = shap.datasets.adult()
model = xgboost.XGBClassifier().fit(X, y)
# compute SHAP values
explainer = shap.Explainer(model, X)
shap_values = explainer(X)
98%|===================| 32071/32561 [00:58<00:00]
简单的蜂群摘要图
蜂群图旨在展示数据集中对模型输出影响最大的特征的密集信息摘要。每个解释实例由每个特征行上的一个点表示。点的x位置由该特征的SHAP值(shap_values.value[instance,feature]
)决定,点沿每个特征行“堆积”以显示密度。颜色用于显示特征的原始值(shap_values.data[instance,feature]
)。在下图中,我们可以看到年龄是平均上最重要的特征,而且年轻人(蓝色)不太可能赚到超过$50k。
[2]:
shap.plots.beeswarm(shap_values)
默认情况下,显示的最大特征数量是十个,但可以通过 max_display
参数进行调整:
[3]:
shap.plots.beeswarm(shap_values, max_display=20)
功能排序
默认情况下,特征是按照 shap_values.abs.mean(0)
排序的,这是每个特征的 SHAP 值的平均绝对值。然而,这种排序更强调广泛的平均影响,而对罕见但影响巨大的情况关注较少。如果我们想找到对个人影响较大的特征,我们可以改为按最大绝对值排序:
[4]:
shap.plots.beeswarm(shap_values, order=shap_values.abs.max(0))
有用的转换
有时在绘制SHAP值之前对其进行转换是有帮助的。下面我们绘制绝对值并将颜色固定为红色。这使得与标准的 shap_values.abs.mean(0)
条形图的对比更加丰富,因为条形图仅绘制了蜂群图中点的平均值。
[5]:
shap.plots.beeswarm(shap_values.abs, color="shap_red")
[6]:
shap.plots.bar(shap_values.abs.mean(0))
自定义颜色
默认情况下,beeswarm
使用 shap.plots.colors.red_blue
颜色映射,但你可以通过 color
参数传递任何 matplotlib 颜色或颜色映射:
[7]:
import matplotlib.pyplot as plt
shap.plots.beeswarm(shap_values, color=plt.get_cmap("cool"))
有更多有用示例的想法吗?我们鼓励提交增加此文档笔记本的拉取请求!