CLI快速入门
本页面解释了如何使用CLI API训练、评估、分析、生成预测以及测量二分类模型的推理速度。
一个端到端的示例可以在这里找到。
安装YDF CLI
1. 前往YDF GitHub的发布页面。
2. 下载适用于您操作系统的最新CLI版本。例如,要下载适用于Linux的CLI版本,请点击“cli_linux.zip”文件旁边的“下载”按钮。
3. 将ZIP文件解压到您选择的目录中,例如unzip cli_linux.zip
。
4. 打开终端窗口并导航到您解压ZIP文件的目录。
每个可执行文件(例如train
、evaluate
)执行不同的任务。例如,train
命令用于训练模型。
每个命令在命令页面中有解释,或者可以使用--help
标志查看:
下载数据集
在这个示例中,我们使用UCI Adult数据集。这是一个二分类数据集,目标是预测个人的收入是否超过$50,000。数据集中的特征包括数值型和类别型。
首先,我们从UCI机器学习库下载数据集的副本:
DATASET_SRC=https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset
wget -q ${DATASET_SRC}/adult_train.csv -O adult_train.csv
wget -q ${DATASET_SRC}/adult_test.csv -O adult_test.csv
训练数据集的前3个示例如下:
$ head -n 4 adult_train.csv
age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
44,Private,228057,7th-8th,4,Married-civ-spouse,Machine-op-inspct,Wife,White,Female,0,0,40,Dominican-Republic,<=50K
20,Private,299047,Some-college,10,Never-married,Other-service,Not-in-family,White,Female,0,0,20,United-States,<=50K
40,Private,342164,HS-grad,9,Separated,Adm-clerical,Unmarried,White,Female,0,0,37,United-States,<=50K
数据集存储在两个CSV文件中,一个用于训练,一个用于测试。YDF可以直接加载CSV文件,这使得使用此数据集非常方便。
当将数据集路径传递给命令时,数据集的格式总是通过前缀指定。例如,路径csv:/path/to/my/file
中的前缀csv:
表示该文件是一个CSV文件。支持的数据集格式列表可以在这里找到。
创建数据规范
数据规范(简称数据集规范)是对数据集的描述。它包括可用列的列表、每列的语义(或类型)以及任何其他元数据,如字典或缺失值的比率。
可以使用infer_dataspec
命令自动计算数据规范并存储在数据规范文件中。
在训练模型之前查看数据规范是检测数据集中问题(如缺失值或错误数据类型)的好方法。
结果如下:
记录数: 22792
列数: 15
按类型划分的列数:
CATEGORICAL: 9 (60%)
NUMERICAL: 6 (40%)
列:
CATEGORICAL: 9 (60%)
3: "education" CATEGORICAL 有字典 词汇量:17 零个OOD项 最频繁项:"HS-grad" 7340 (32.2043%)
14: "income" CATEGORICAL 有字典 词汇量:3 零个OOD项 最频繁项:"<=50K" 17308 (75.9389%)
5: "marital_status" CATEGORICAL 有字典 词汇量:8 零个OOD项 最频繁项:"Married-civ-spouse" 10431 (45.7661%)
13: "native_country" CATEGORICAL 缺失值数:407 (1.78571%) 有字典 词汇量:41 OOD项数:1 (0.00446728%) 最频繁项:"United-States" 20436 (91.2933%)
6: "occupation" CATEGORICAL 缺失值数:1260 (5.52826%) 有字典 词汇量:14 OOD项数:1 (0.00464425%) 最频繁项:"Prof-specialty" 2870 (13.329%)
8: "race" CATEGORICAL 有字典 词汇量:6 零个OOD项 最频繁项:"White" 19467 (85.4115%)
7: "relationship" CATEGORICAL 有字典 词汇量:7 零个OOD项 最频繁项:"Husband" 9191 (40.3256%)
9: "sex" CATEGORICAL 有字典 词汇量:3 零个OOD项 最频繁项:"Male" 15165 (66.5365%)
1: "workclass" CATEGORICAL 缺失值数:1257 (5.51509%) 有字典 词汇量:8 OOD项数:1 (0.0046436%) 最频繁项:"Private" 15879 (73.7358%)
NUMERICAL: 6 (40%)
0: "age" NUMERICAL 均值:38.6153 最小值:17 最大值:90 标准差:13.661
10: "capital_gain" NUMERICAL 均值:1081.9 最小值:0 最大值:99999 标准差:7509.48
11: "capital_loss" NUMERICAL 均值:87.2806 最小值:0 最大值:4356 标准差:403.01
4: "education_num" NUMERICAL 均值:10.0927 最小值:1 最大值:16 标准差:2.56427
2: "fnlwgt" NUMERICAL 均值:189879 最小值:12285 最大值:1.4847e+06 标准差:106423
12: "hours_per_week" NUMERICAL 均值:40.3955 最小值:1 最大值:99 标准差:12.249
术语:
nas: 不可用(即缺失)值的数量。
ood: 字典外。
手动定义: 由用户手动定义类型的属性,即类型不是自动推断的。
分词: 属性值通过分词获得。
有字典: 属性附加有字符串字典,例如存储为字符串的分类属性。
词汇量: 唯一值的数量。
例如,education
列是一个分类列,具有17个唯一可能的值。最频繁的值是HS-grad
(占所有值的32%)。
(可选) 使用指南创建数据规范
在示例中,列的语义被正确检测。然而,当值的表示形式不明确时,情况可能并非如此。例如,枚举值(即以整数表示的分类值)无法在.csv文件中自动检测其语义。
在这种情况下,我们可以使用额外的标志重新运行infer_dataspec
命令,以指示错误检测列的真实语义。例如,要将age
强制检测为数值列,我们可以运行:
# 强制将'age'检测为数值列。
cat <<EOF > guide.pbtxt
column_guides {
column_name_pattern: "^age$"
type: NUMERICAL
}
EOF
./infer_dataspec --dataset=csv:adult_train.csv --guide=guide.pbtxt --output=dataspec.pbtxt
训练模型
模型通过train
命令进行训练。标签、特征、超参数和其他训练设置在训练配置文件中指定。
# 创建训练配置文件
cat <<EOF > train_config.pbtxt
task: CLASSIFICATION
label: "income"
learner: "GRADIENT_BOOSTED_TREES"
# 更改学习器的特定超参数。
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
num_trees: 500
}
EOF
# 训练模型
./train \
--dataset=csv:adult_train.csv \
--dataspec=dataspec.pbtxt \
--config=train_config.pbtxt \
--output=model
结果:
[INFO train.cc:96] 开始训练模型。
[INFO abstract_learner.cc:119] 未指定输入特征。使用所有可用输入特征作为输入信号。
[INFO abstract_learner.cc:133] 标签 "income" 已从输入特征集中移除。
[INFO vertical_dataset_io.cc:74] 已扫描100个样本。
[INFO vertical_dataset_io.cc:80] 已读取22792个样本。内存:使用量:1MB,分配量:1MB。跳过了0(0%)个样本。
[INFO abstract_learner.cc:119] 未指定输入特征。使用所有可用输入特征作为输入信号。
[INFO abstract_learner.cc:133] 标签 "income" 已从输入特征集中移除。
[INFO gradient_boosted_trees.cc:405] 默认损失设置为 BINOMIAL_LOG_LIKELIHOOD
[INFO gradient_boosted_trees.cc:1008] 正在训练包含22792个样本和14个特征的梯度提升树。
[INFO gradient_boosted_trees.cc:1051] 20533个样本用于训练,2259个样本用于验证
[INFO gradient_boosted_trees.cc:1434] num-trees:1 train-loss:1.015975 train-accuracy:0.761895 valid-loss:1.071430 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:1436] num-trees:2 train-loss:0.955303 train-accuracy:0.761895 valid-loss:1.007908 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:2871] 由于验证损失不再减少,提前停止训练。最佳验证损失:0.579583
[INFO gradient_boosted_trees.cc:230] 将模型截断为136棵树,即136次迭代。
[INFO gradient_boosted_trees.cc:264] 最终模型 num-trees:136 valid-loss:0.579583 valid-accuracy:0.870297
几点说明:
-
由于未指定输入特征,除标签外的所有列都被用作输入特征。
-
DFs 原生支持数值、分类和分类集特征,以及缺失值。数值特征不需要归一化,分类字符串值也不需要字典编码。
-
除了
num_trees
超参数外,未指定其他训练超参数。所有超参数的默认值都设置为在大多数情况下提供合理的结果。我们将在后面讨论替代的默认值(称为超参数模板)和超参数的自动调优。所有超参数及其默认值的列表可在超参数页面中找到。 -
未为训练提供验证数据集。并非所有学习器都需要验证数据集。然而,本示例中使用的
GRADIENT_BOOSTED_TREES
学习器如果启用了早停(默认情况下启用),则需要验证数据集。在这种情况下,10% 的训练数据集用于验证。此比率可以通过validation_ratio
参数进行更改。或者,可以使用--valid_dataset
标志提供验证数据集。最终模型包含136棵树,验证准确率约为0.8702。
显示模型信息
使用 show_model
命令显示模型的详细信息。
结果示例:
类型: "GRADIENT_BOOSTED_TREES"
任务: 分类
标签: "income"
输入特征 (14):
age
workclass
fnlwgt
education
education_num
marital_status
occupation
relationship
race
sex
capital_gain
capital_loss
hours_per_week
native_country
无权重
变量重要性: MEAN_MIN_DEPTH:
1. "income" 4.868164 ################
2. "sex" 4.625136 #############
3. "race" 4.590606 #############
...
13. "occupation" 3.640103 ####
14. "marital_status" 3.626898 ###
15. "age" 3.219872
变量重要性: NUM_AS_ROOT:
1. "age" 28.000000 ################
2. "marital_status" 22.000000 ############
3. "capital_gain" 19.000000 ##########
...
11. "education_num" 3.000000
12. "occupation" 2.000000
13. "native_country" 2.000000
变量重要性: NUM_NODES:
1. "occupation" 516.000000 ################
2. "age" 431.000000 #############
3. "education" 424.000000 ############
...
12. "education_num" 73.000000 #
13. "sex" 39.000000
14. "race" 26.000000
变量重要性: SUM_SCORE:
1. "relationship" 3103.387636 ################
2. "capital_gain" 2041.557944 ##########
3. "education" 1090.544247 #####
...
12. "workclass" 176.876787
13. "sex" 49.287215
14. "race" 13.923084
损失: BINOMIAL_LOG_LIKELIHOOD
验证损失值: 0.579583
每迭代树的数量: 1
节点格式: BLOB_SEQUENCE
树的数量: 136
总节点数: 7384
每棵树的节点数:
计数: 136 平均值: 54.2941 标准差: 5.7779
最小值: 33 最大值: 63 忽略: 0
----------------------------------------------
[ 33, 34) 2 1.47% 1.47% #
...
[ 60, 62) 16 11.76% 96.32% ########
[ 62, 63] 5 3.68% 100.00% ##
叶节点的深度:
计数: 3760 平均值: 4.87739 标准差: 0.412078
最小值: 2 最大值: 5 忽略: 0
----------------------------------------------
[ 2, 3) 14 0.37% 0.37%
[ 3, 4) 75 1.99% 2.37%
[ 4, 5) 269 7.15% 9.52% #
[ 5, 5] 3402 90.48% 100.00% ##########
叶节点的训练观测数:
计数: 3760 平均值: 742.683 标准差: 2419.64
最小值: 5 最大值: 19713 忽略: 0
----------------------------------------------
[ 5, 990) 3270 86.97% 86.97% ##########
[ 990, 1975) 163 4.34% 91.30%
...
[ 17743, 18728) 10 0.27% 99.55%
[ 18728, 19713] 17 0.45% 100.00%
节点中的属性:
516 : occupation [CATEGORICAL]
431 : age [NUMERICAL]
424 : education [CATEGORICAL]
420 : fnlwgt [NUMERICAL]
297 : capital_gain [NUMERICAL]
291 : hours_per_week [NUMERICAL]
266 : capital_loss [NUMERICAL]
245 : native_country [CATEGORICAL]
224 : relationship [CATEGORICAL]
206 : workclass [CATEGORICAL]
166 : marital_status [CATEGORICAL]
73 : education_num [NUMERICAL]
39 : sex [CATEGORICAL]
26 : race [CATEGORICAL]
深度 <= 0 的节点中的属性:
28 : age [NUMERICAL]
22 : marital_status [CATEGORICAL]
19 : capital_gain [NUMERICAL]
12 : capital_loss [NUMERICAL]
11 : hours_per_week [NUMERICAL]
11 : fnlwgt [NUMERICAL]
8 : relationship [CATEGORICAL]
8 : education [CATEGORICAL]
6 : race [CATEGORICAL]
4 : sex [CATEGORICAL]
3 : education_num [NUMERICAL]
2 : native_country [CATEGORICAL]
2 : occupation [CATEGORICAL]
...
节点中的条件类型:
1844 : ContainsBitmapCondition
1778 : HigherCondition
2 : ContainsCondition
深度 <= 0 的节点中的条件类型:
84 : HigherCondition
52 : ContainsBitmapCondition
深度 <= 1 的节点中的条件类型:
243 : HigherCondition
165 : ContainsBitmapCondition
...
--full_definition
标志打印模型树的结构。
评估模型
使用 evaluate
命令计算并打印评估结果,结果可以以文本格式(--format=text
,默认)或带有图表的 HTML 格式(--format=html
)输出。
结果
评估:
预测数量(无权重): 9769
预测数量(有权重): 9769
任务: 分类
标签: 收入
准确率: 0.874399 CI95[W][0.86875 0.879882]
对数损失: 0.27768
错误率: 0.125601
默认准确率: 0.758727
默认对数损失: 0.552543
默认错误率: 0.241273
混淆矩阵:
truth\prediction
<OOD> <=50K >50K
<OOD> 0 0 0
<=50K 0 6971 441
>50K 0 786 1571
总计: 9769
一对一分类:
"<=50K" 对其他类别
auc: 0.929207 CI95[H][0.924358 0.934056] CI95[B][0.924076 0.934662]
p/r-auc: 0.975657 CI95[L][0.971891 0.97893] CI95[B][0.973397 0.977947]
ap: 0.975656 CI95[B][0.973393 0.977944]
">50K" 对其他类别
auc: 0.929207 CI95[H][0.921866 0.936549] CI95[B][0.923642 0.934566]
p/r-auc: 0.830708 CI95[L][0.815025 0.845313] CI95[B][0.817588 0.843956]
ap: 0.830674 CI95[B][0.817513 0.843892]
观察:
- 测试数据集包含 9769 个样本。
- 测试准确率为 0.874399,95% 置信区间为 [0.86875; 0.879882]。
- 测试 AUC 为 0.929207,95% 置信区间为 [0.924358 0.934056](使用闭式计算时)和 [0.973397 0.977947](使用自助法计算时)。
- PR-AUC 和 AP 指标也可用。
以下命令评估模型并将评估报告导出到 HTML 文件。
# 评估模型并将结果打印到 Html 文件中
./evaluate --dataset=csv:adult_test.csv --model=model --format=html > evaluation.html
生成预测
使用 predict
命令计算预测并导出到文件。
# 将模型的预测导出到 csv 文件
./predict --dataset=csv:adult_test.csv --model=model --output=csv:predictions.csv
# 显示前 3 个样本的预测
head -n 4 predictions.csv
结果:
基准模型速度
在时间敏感的应用中,模型的推理速度至关重要。benchmark_inference
命令测量模型的平均推理时间。
YDF 有多种算法来计算模型的预测。这些算法在速度和覆盖范围上有所不同。生成预测时,YDF 自动使用兼容的最快算法。
benchmark_inference
显示所有兼容算法的速度。
推理算法是单线程的,这意味着它们一次只能处理一个数据点。用户需要使用多线程来并行化推理。
结果:
批量大小 : 100 运行次数 : 20
每样本时间(微秒) 每批次时间(微秒) 方法
----------------------------------------
0.89 89 GradientBoostedTreesQuickScorerExtended [虚拟接口]
5.8475 584.75 GradientBoostedTreesGeneric [虚拟接口]
12.485 1248.5 通用慢速引擎
----------------------------------------
我们看到模型平均每样本运行时间为 0.89 微秒。