pip install ydf scikit-learn umap-learn plotly -U -q
Note: you may need to restart the kernel to use updated packages.
import ydf # Yggdrasil决策森林
import numpy as np
import pandas as pd # 我们使用Pandas加载小型数据集。
什么是邻居示例和反事实示例?¶
邻居示例是指在模型看来彼此相似的示例,即模型出于相同原因对它们作出相同预测的示例。通过观察一个示例的邻居之间的相似性和差异性,可以很好地理解模型对该示例的预测。
反事实示例是指与感兴趣的示例标签不同的邻居示例。当模型的预测令人大吃一惊时,查看其反事实示例是理解原因的一个好方法。
什么是示例距离?¶
两个示例之间的距离是一个介于0和1之间的数值,表示两个示例之间的差异程度。感兴趣示例的邻居是指具有最小距离的示例。
决策森林模型定义了一种隐含的接近度或相似性测量,称为距离。该距离表示模型在多大程度上认为两个示例相似。非正式地说,如果两个示例属于同一类别并且原因相同,则它们是接近的。
这种距离对于理解模型及其预测非常有用。例如,我们可以利用它进行聚类、流形学习,或简单地查看最接近测试示例的训练示例(即其邻居示例;其中标签与该测试示例不同的称为反事实示例)。这可以帮助我们理解模型为何作出其预测。
请记住,决策森林的距离度量只是数据集中许多合理的距离度量之一。其众多优点之一是能够比较具有不同尺度和不同语义的特征。
在这个笔记本中,我们将训练一个模型并使用其距离来:
找到与测试示例相邻的训练示例,并利用它们来解释模型的预测。
将所有示例映射到一个交互式二维图中(也称为二维流形),并自动检测表现相似的二维示例聚类。
应用层次聚类来解释模型作为一个整体是如何工作的。
更多信息: Leo Breiman,即随机森林学习算法的作者,提出了一种通过预训练的随机森林(RF)模型来测量两个示例之间的接近度的方法。他将这种方法称为"[...] 随机森林中最有用的工具之一."。在使用随机森林模型时,这是YDF使用的距离。
找到最接近测试样本的训练样本¶
让我们下载一个分类数据集。
# Base URL of the YDF test datasets on GitHub.
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Load the train and test CSV files of the "adult" dataset into DataFrames.
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Show the first 5 training examples.
train_ds.head(5)
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
我们在这个数据集上训练一个随机森林。
# Train a Random Forest on the training set; "income" is the label column.
# Only the label is passed to the learner; no other setting is overridden here.
model = ydf.RandomForestLearner(label="income").train(train_ds)
Train model on 22792 examples Model trained in 0:00:00.922407
我们需要选择一个例子来解释。让我们选择测试数据集中第一个例子。
selected_example_idx = 0 # Change this to select another example.
# Slice (rather than index) so the result is still a one-row DataFrame.
selected_example = test_ds[selected_example_idx:(selected_example_idx+1)]
selected_example
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
在这个例子中,模型预测:
# Run the model on the single selected example.
model.predict(selected_example)
array([0.01], dtype=float32)
换句话说,负类 <=50K 的概率为 $1-0.01=99\%$。
现在,我们计算所选测试示例与所有训练示例之间的距离。
# Distance between every training example and the selected test example.
# squeeze() drops the length-1 axes of the result, giving the flat vector
# that is printed — presumably the raw result is (num_train, 1); TODO confirm.
distances = model.distance(train_ds, selected_example).squeeze()
print("distances:", distances)
distances: [1. 1. 1. ... 0.99333334 0.99666667 1. ]
让我们找到与我们选择的示例距离最小的五个训练示例。
# argsort orders indices by increasing distance, so the first five are the
# training examples closest to the selected test example.
close_train_idxs = np.argsort(distances)[:5]
print("close_train_idxs:", close_train_idxs)
# NOTE(review): the label says "test examples", but the rows displayed next
# are taken from train_ds (the nearest training examples).
print("Selected test examples:")
train_ds.iloc[close_train_idxs]
close_train_idxs: [16596 21845 10321 7299 14721] Selected test examples:
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
16596 | 41 | State-gov | 26892 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
21845 | 37 | State-gov | 60227 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 38 | United-States | <=50K |
10321 | 40 | Private | 82161 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
7299 | 30 | State-gov | 158291 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
14721 | 32 | State-gov | 171111 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 37 | United-States | <=50K |
观察:
- 对于所选的示例,模型预测的类别为 <=50K。在五个最接近的示例中,模型有相同的预测结果。
- 这些最接近的示例共享许多特征值,例如教育水平、婚姻状况、职业、种族,以及每周工作在37到40个小时之间。这很好地解释了为什么这些示例彼此接近。
- 这些示例的年龄范围在30到41岁之间,意味着模型认为这一年龄范围内的取值对这些示例而言是等价的。
import plotly.graph_objs as go
from plotly.offline import iplot # 对于交互式图表
import plotly.io as pio
# Render Plotly figures with the "colab" renderer.
# NOTE(review): presumably needs changing when not running in Colab — confirm.
pio.renderers.default="colab"
# Pairwise distances between all test examples, as measured by the model.
distances = model.distance(test_ds, test_ds)

# Arrange the examples on a two-dimensional manifold.
# Choose the projection method here; different methods and different
# parameters produce different projections.
manifold_lib = "UMAP"  # "UMAP" or "TSNE"

if manifold_lib == "TSNE":
    # Imported lazily so only the selected library has to be installed.
    from sklearn.manifold import TSNE

    manifold = TSNE(
        # Number of displayed dimensions. Can also be 3.
        n_components=2,
        # Controls the shape of the projection. Higher values give clusters
        # that are more distinct but also more collapsed. Typically 5 to 50.
        perplexity=20,
        # `distances` is already a distance matrix, not a feature matrix.
        metric="precomputed",
        # PCA initialisation is not usable with a precomputed metric.
        init="random",
        verbose=1,
        learning_rate="auto",
    ).fit_transform(distances)
elif manifold_lib == "UMAP":
    # Imported lazily so only the selected library has to be installed.
    import umap

    manifold = umap.UMAP(
        # Number of displayed dimensions. Can also be 3.
        n_components=2,
        # Balances local versus global structure in the data.
        n_neighbors=15,
        # `distances` is already a distance matrix, not a feature matrix.
        metric="precomputed",
    ).fit_transform(distances)
else:
    raise ValueError(f"Unknown lib: {manifold_lib}")
/usr/local/google/home/gbm/my_venv/lib/python3.11/site-packages/umap/umap_.py:1858: UserWarning: using precomputed metric; inverse_transform will be unavailable
让我们用示例特征创建一个交互式图表。
def example_to_html(example):
    """Render one example as an HTML snippet usable as a hover label.

    Args:
        example: Mapping-like object exposing ``items()`` that yields
            ``(feature_name, value)`` pairs, e.g. one dataset row.

    Returns:
        A string with one ``<b>name:</b> value`` entry per feature, joined
        with ``<br>`` line breaks. Empty if the example has no features.
    """
    return "<br>".join(f"<b>{k}:</b> {v}" for k, v in example.items())
def interactive_plot(dataset, projections):
    """Display an interactive scatter plot of the projected examples.

    Args:
        dataset: DataFrame of examples; must contain an "income" column.
            Each row is rendered with `example_to_html` as the hover text.
        projections: Array indexable as ``projections[:, 0]`` and
            ``projections[:, 1]``, giving the x and y coordinate of each
            example (presumably in the same row order as `dataset` — confirm
            against the caller).

    Returns:
        None. The figure is displayed through `iplot`.
    """
    # Examples labelled ">50K" are drawn in blue, all others in red.
    colors = (dataset["income"] == ">50K").map({False: "red", True: "blue"})
    labels = list(dataset.apply(example_to_html, axis=1).values)
    args = {
        "data": [
            go.Scatter(
                x=projections[:, 0],
                y=projections[:, 1],
                text=labels,
                mode="markers",
                marker={"color": colors, "size": 3},
            )
        ],
        "layout": go.Layout(width=500, height=500, template="simple_white"),
    }
    iplot(args)


interactive_plot(test_ds, manifold)
(此处为上述代码生成的交互式散点图输出。)
注意: 将鼠标移动到图形上以查看示例的值。
颜色代表标签。我们可以看到均匀颜色的簇(所有标签相同的簇)和混合颜色的簇(模型在做出良好预测方面遇到困难的簇)。
你能理解这些簇吗?
from sklearn.cluster import AgglomerativeClustering

# Number of clusters to extract from the hierarchy.
num_clusters = 6

# `distances` is passed as a ready-made distance matrix ("precomputed"),
# and clusters are merged using average linkage.
clustering = AgglomerativeClustering(
    n_clusters=num_clusters,
    metric="precomputed",
    linkage="average",
)
# fit() returns the estimator itself, so `clustering` is the fitted model.
clustering.fit(distances)
接下来,我们打印特征的统计信息以及每个聚类中的一个示例。
# Import the submodule explicitly so `IPython.display` is guaranteed to be
# loaded; this still binds the top-level `IPython` name.
import IPython.display

# For each cluster: print its size, then show summary statistics and the
# first example assigned to it.
for cluster_idx in range(num_clusters):
    # Boolean mask selecting the test examples assigned to this cluster.
    selected_examples = test_ds[clustering.labels_ == cluster_idx]
    print(f"Cluster #{cluster_idx} with {len(selected_examples)} examples")
    print("=============================")
    IPython.display.display(selected_examples.describe())
    # Slice (rather than index) so a one-row DataFrame is displayed.
    IPython.display.display(selected_examples.iloc[:1])
Cluster #0 with 2879 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 | 2879.000000 |
mean | 42.860021 | 184706.465439 | 8.879125 | 200.963876 | 32.425842 | 42.555054 |
std | 12.426582 | 99424.684674 | 1.929070 | 852.256462 | 231.362238 | 11.910265 |
min | 18.000000 | 19395.000000 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 33.000000 | 115465.500000 | 9.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 41.000000 | 176681.000000 | 9.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 51.000000 | 231872.500000 | 10.000000 | 0.000000 | 0.000000 | 46.000000 |
max | 90.000000 | 671292.000000 | 12.000000 | 5013.000000 | 2179.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 40 | Private | 121772 | Assoc-voc | 11 | Married-civ-spouse | Craft-repair | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | NaN | >50K |
Cluster #1 with 5131 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 5131.000000 | 5.131000e+03 | 5131.000000 | 5131.000000 | 5131.000000 | 5131.000000 |
mean | 34.026895 | 1.931768e+05 | 9.726954 | 103.289222 | 57.424479 | 37.824401 |
std | 13.371512 | 1.055196e+05 | 2.395434 | 642.138022 | 328.764194 | 12.401540 |
min | 17.000000 | 1.921400e+04 | 1.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 23.000000 | 1.205865e+05 | 9.000000 | 0.000000 | 0.000000 | 35.000000 |
50% | 31.000000 | 1.817210e+05 | 10.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 42.000000 | 2.416855e+05 | 11.000000 | 0.000000 | 0.000000 | 40.000000 |
max | 90.000000 | 1.038553e+06 | 16.000000 | 7443.000000 | 3770.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
Cluster #2 with 220 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 220.000000 | 220.000000 | 220.000000 | 220.0 | 220.000000 | 220.000000 |
mean | 44.863636 | 182932.690909 | 11.977273 | 0.0 | 1996.745455 | 46.700000 |
std | 11.372463 | 89132.990647 | 2.314227 | 0.0 | 174.160632 | 11.490357 |
min | 22.000000 | 20953.000000 | 9.000000 | 0.0 | 1825.000000 | 12.000000 |
25% | 36.750000 | 125575.250000 | 10.000000 | 0.0 | 1887.000000 | 40.000000 |
50% | 43.000000 | 169627.500000 | 13.000000 | 0.0 | 1902.000000 | 41.000000 |
75% | 51.000000 | 213384.500000 | 14.000000 | 0.0 | 1977.000000 | 50.000000 |
max | 83.000000 | 530099.000000 | 16.000000 | 0.0 | 2603.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
25 | 48 | Self-emp-not-inc | 191277 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 1902 | 60 | United-States | >50K |
Cluster #3 with 1012 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 1012.000000 | 1.012000e+03 | 1012.000000 | 1012.000000 | 1012.000000 | 1012.000000 |
mean | 43.610672 | 1.862960e+05 | 13.541502 | 119.073123 | 14.341897 | 44.180830 |
std | 11.334174 | 1.074333e+05 | 0.874553 | 675.355176 | 152.606687 | 12.260441 |
min | 23.000000 | 2.232800e+04 | 13.000000 | 0.000000 | 0.000000 | 1.000000 |
25% | 35.000000 | 1.148158e+05 | 13.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 43.000000 | 1.756480e+05 | 13.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 50.000000 | 2.301228e+05 | 14.000000 | 0.000000 | 0.000000 | 50.000000 |
max | 90.000000 | 1.097453e+06 | 16.000000 | 5013.000000 | 1977.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2 | 40 | Private | 193524 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 0 | 60 | United-States | >50K |
Cluster #4 with 46 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 46.000000 | 46.000000 | 46.000000 | 46.000000 | 46.000000 | 46.000000 |
mean | 47.913043 | 171906.630435 | 15.543478 | 280.413043 | 252.260870 | 43.021739 |
std | 10.897148 | 81143.023865 | 0.503610 | 1088.008529 | 679.865949 | 16.043675 |
min | 32.000000 | 33155.000000 | 15.000000 | 0.000000 | 0.000000 | 6.000000 |
25% | 39.000000 | 115998.000000 | 15.000000 | 0.000000 | 0.000000 | 40.000000 |
50% | 48.500000 | 163298.000000 | 16.000000 | 0.000000 | 0.000000 | 40.000000 |
75% | 53.000000 | 211152.250000 | 16.000000 | 0.000000 | 0.000000 | 50.000000 |
max | 79.000000 | 345259.000000 | 16.000000 | 4787.000000 | 2824.000000 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
618 | 36 | Private | 103110 | Doctorate | 16 | Never-married | Prof-specialty | Not-in-family | White | Male | 0 | 0 | 40 | England | <=50K |
Cluster #5 with 481 examples =============================
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 481.000000 | 481.000000 | 481.000000 | 481.000000 | 481.0 | 481.000000 |
mean | 45.621622 | 191274.322245 | 11.806653 | 19103.544699 | 0.0 | 46.636175 |
std | 11.010141 | 103664.004053 | 2.507912 | 25872.337100 | 0.0 | 11.647901 |
min | 20.000000 | 19302.000000 | 1.000000 | 5178.000000 | 0.0 | 2.000000 |
25% | 38.000000 | 119793.000000 | 10.000000 | 7298.000000 | 0.0 | 40.000000 |
50% | 44.000000 | 175232.000000 | 13.000000 | 10520.000000 | 0.0 | 45.000000 |
75% | 52.000000 | 235786.000000 | 14.000000 | 15024.000000 | 0.0 | 50.000000 |
max | 78.000000 | 617021.000000 | 16.000000 | 99999.000000 | 0.0 | 99.000000 |
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
21 | 44 | Private | 343591 | HS-grad | 9 | Divorced | Craft-repair | Not-in-family | White | Female | 14344 | 0 | 40 | United-States | >50K |