使用 scikit-learn 进行人口普查收入分类
此示例使用了来自UCI机器学习数据仓库的标准成人人口普查收入数据集。我们使用sci-kit learn训练了一个k近邻分类器,然后解释了预测结果。
[1]:
import sklearn
import shap
加载人口普查数据
[2]:
X, y = shap.datasets.adult()
X["Occupation"] *= 1000 # to show the impact of feature scale on KNN predictions
X_display, y_display = shap.datasets.adult(display=True)
X_train, X_valid, y_train, y_valid = sklearn.model_selection.train_test_split(
X, y, test_size=0.2, random_state=7
)
训练一个 k-近邻分类器
这里我们直接在数据上进行训练,没有任何归一化处理。
[4]:
knn = sklearn.neighbors.KNeighborsClassifier()
knn.fit(X_train, y_train)
[4]:
KNeighborsClassifier()
解释预测
通常我们会使用一个逻辑链接函数,以便让加性特征输入更好地映射到模型的概率输出空间,但由于knn可以产生无限的逻辑奇数比,因此在这个例子中我们不这样做。
值得注意的是,职业是我们在解释的1000个预测中的主导特征。这是因为它的数值变化比其他特征更大,因此它对k近邻计算的影响更大。
[5]:
def f(x):
return knn.predict_proba(x)[:, 1]
med = X_train.median().values.reshape((1, X_train.shape[1]))
explainer = shap.Explainer(f, med)
shap_values = explainer(X_valid.iloc[0:1000, :])
Permutation explainer: 1001it [00:25, 38.69it/s]
[5]:
shap.plots.waterfall(shap_values[0])
汇总蜂群图是查看整个数据集中所有特征相对影响的更好方法。特征按其在所有样本中SHAP值大小的总和排序。
[7]:
shap.plots.beeswarm(shap_values)
热图绘制提供了模型的另一种全局视图,这次重点关注人口子群体。
[8]:
shap.plots.heatmap(shap_values)
在训练模型之前对数据进行归一化
在这里,我们在标准化数据上重新训练了一个KNN模型。
[9]:
# normalize data
dtypes = list(zip(X.dtypes.index, map(str, X.dtypes)))
X_train_norm = X_train.copy()
X_valid_norm = X_valid.copy()
for k, dtype in dtypes:
m = X_train[k].mean()
s = X_train[k].std()
X_train_norm[k] -= m
X_train_norm[k] /= s
X_valid_norm[k] -= m
X_valid_norm[k] /= s
[10]:
knn_norm = sklearn.neighbors.KNeighborsClassifier()
knn_norm.fit(X_train_norm, y_train)
[10]:
KNeighborsClassifier()
解释预测
当我们解释新KNN模型的预测时,我们发现职业不再是主导特征,而是更具预测性的特征,如婚姻状况,驱动了大多数预测。这是一个简单的例子,说明了为什么解释模型为何做出预测可以揭示训练过程中的问题。
[11]:
def f(x):
return knn_norm.predict_proba(x)[:, 1]
med = X_train_norm.median().values.reshape((1, X_train_norm.shape[1]))
explainer = shap.Explainer(f, med)
shap_values_norm = explainer(X_valid_norm.iloc[0:1000, :])
Permutation explainer: 1001it [01:26, 11.55it/s]
通过一个总结图,我们可以看到婚姻状况平均来说是最重要的,但其他特征(如资本收益)对特定个体可能有更大的影响。
[12]:
shap.summary_plot(shap_values_norm, X_valid.iloc[0:1000, :])
依赖性散点图展示了受教育年限如何增加年收入超过5万美元的机会。
[14]:
shap.plots.scatter(shap_values_norm[:, "Education-Num"])
有更多有用示例的想法吗?我们鼓励提交增加此文档笔记本的拉取请求!