使用 LightGBM 进行人口普查收入分类
本笔记本演示了如何使用 LightGBM 来预测个人年收入超过 50,000 美元的概率。它使用了标准的 UCI Adult 收入数据集。要下载此笔记本的副本,请访问 github。
像 LightGBM 这样的梯度提升机方法是这些类型的预测问题的最先进方法,适用于多模态的表格样式输入数据。Tree SHAP(arXiv 论文)允许为树集成方法精确计算 SHAP 值,并且已经直接集成到 C++ LightGBM 代码库中。这使得可以在不采样且不提供背景数据集的情况下快速精确计算 SHAP 值(因为背景是从树的覆盖范围推断出来的)。
这里我们演示如何使用SHAP值来理解LightGBM模型预测。
[1]:
import lightgbm as lgb
from sklearn.model_selection import train_test_split
import shap
# print the JS visualization code to the notebook
shap.initjs()
加载数据集
[2]:
X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)
# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = lgb.Dataset(X_train, label=y_train)
d_test = lgb.Dataset(X_test, label=y_test)
训练模型
[3]:
params = {
"max_bin": 512,
"learning_rate": 0.05,
"boosting_type": "gbdt",
"objective": "binary",
"metric": "binary_logloss",
"num_leaves": 10,
"verbose": -1,
"min_data": 100,
"boost_from_average": True,
"early_stopping_round": 50,
}
model = lgb.train(
params,
d_train,
10000,
valid_sets=[d_test],
)
解释预测
在这里,我们使用集成到 Light GBM 中的 Tree SHAP 实现来解释整个数据集(32561 个样本)。Light GBM 的 Tree SHAP 实现通过 shap.TreeExplainer.shap_values
方法调用。
[4]:
explainer = shap.TreeExplainer(model)
shap_values = explainer(X)
[5]:
shap.force_plot(
explainer.expected_value, shap_values.values[1, :], X_display.iloc[0, :]
)
[5]:
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
可视化多个预测
为了使浏览器保持良好状态,我们只可视化1,000个个体。
[6]:
shap.force_plot(
explainer.expected_value, shap_values.values[:1000, :], X_display.iloc[:1000, :]
)
[6]:
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.
SHAP 摘要图
我们没有使用典型的特征重要性条形图,而是使用每个特征的 SHAP 值的密度散点图来识别每个特征对验证数据集中个体模型输出的影响程度。特征按所有样本中 SHAP 值大小的总和排序。值得注意的是,关系特征对模型的总影响大于资本收益特征,但对于资本收益重要的样本,其影响大于年龄。换句话说,资本收益对少数预测的影响很大,而年龄对所有预测的影响较小。
请注意,当散点无法排成一行时,它们会堆积起来以显示密度,每个点的颜色代表该个体的特征值。
[7]:
shap.summary_plot(shap_values, X)
SHAP 依赖图
SHAP 依赖图显示了单个特征在整个数据集中的影响。它们绘制了特征值与该特征在多个样本中的 SHAP 值的关系。SHAP 依赖图类似于部分依赖图,但考虑了特征中存在的交互效应,并且仅在数据支持的输入空间区域中定义。SHAP 值在单个特征值处的垂直分散由交互效应驱动,并选择另一个特征进行着色以突出可能的交互。
[8]:
for name in X_train.columns:
shap.dependence_plot(name, shap_values.values, X, display_features=X_display)
训练一个每棵树只有两个叶子的模型,因此没有特征之间的交互项
强制模型不包含交互项意味着特征对结果的影响不依赖于任何其他特征的值。这在下面的SHAP依赖图中表现为没有垂直分布。垂直分布反映了单个特征值在不同上下文中对模型输出的影响可能不同,这取决于个体所具有的其他特征。然而,对于没有交互项的模型,特征的影响始终相同,无论个体可能具有哪些其他属性。
与传统的偏依赖图相比,SHAP依赖图的一个优势是能够区分有无交互项的模型。换句话说,SHAP依赖图通过散点图在给定特征值处的垂直方差,给出了交互项大小的概念。
[9]:
params = {
"max_bin": 512,
"learning_rate": 0.1,
"boosting_type": "gbdt",
"objective": "binary",
"metric": "binary_logloss",
"num_leaves": 2,
"verbose": -1,
"min_data": 100,
"boost_from_average": True,
"early_stopping_round": 50,
}
model_ind = lgb.train(
params,
d_train,
20000,
valid_sets=[d_test],
)
[10]:
explainer = shap.TreeExplainer(model_ind)
shap_values_ind = explainer(X)
请注意,下面的交互颜色条对本模型无意义,因为它没有交互。
[11]:
for name in X_train.columns:
shap.dependence_plot(name, shap_values_ind.values, X, display_features=X_display)