Skip to content

CLI快速入门

本页面解释了如何使用CLI API训练、评估、分析、生成预测以及测量二分类模型的推理速度。

一个端到端的示例可以在这里找到。

安装YDF CLI

1. 前往YDF GitHub的发布页面

2. 下载适用于您操作系统的最新CLI版本。例如,要下载适用于Linux的CLI版本,请点击“cli_linux.zip”文件旁边的“下载”按钮。

3. 将ZIP文件解压到您选择的目录中,例如unzip cli_linux.zip

4. 打开终端窗口并导航到您解压ZIP文件的目录。

每个可执行文件(例如trainevaluate)执行不同的任务。例如,train命令用于训练模型。

每个命令在命令页面中有解释,或者可以使用--help标志查看:

# 打印'train'命令的帮助信息
./train --help

下载数据集

在这个示例中,我们使用UCI Adult数据集。这是一个二分类数据集,目标是预测个人的收入是否超过$50,000。数据集中的特征包括数值型和类别型。

首先,我们从UCI机器学习库下载数据集的副本:

DATASET_SRC=https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset
wget -q ${DATASET_SRC}/adult_train.csv -O adult_train.csv
wget -q ${DATASET_SRC}/adult_test.csv -O adult_test.csv

训练数据集的前3个示例如下:

$ head -n 4 adult_train.csv

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income
44,Private,228057,7th-8th,4,Married-civ-spouse,Machine-op-inspct,Wife,White,Female,0,0,40,Dominican-Republic,<=50K
20,Private,299047,Some-college,10,Never-married,Other-service,Not-in-family,White,Female,0,0,20,United-States,<=50K
40,Private,342164,HS-grad,9,Separated,Adm-clerical,Unmarried,White,Female,0,0,37,United-States,<=50K

数据集存储在两个CSV文件中,一个用于训练,一个用于测试。YDF可以直接加载CSV文件,这使得使用此数据集非常方便。

当将数据集路径传递给命令时,数据集的格式总是通过前缀指定。例如,路径csv:/path/to/my/file中的前缀csv:表示该文件是一个CSV文件。支持的数据集格式列表可以在这里找到。

创建数据规范

数据规范(简称数据集规范)是对数据集的描述。它包括可用列的列表、每列的语义(或类型)以及任何其他元数据,如字典或缺失值的比率。

可以使用infer_dataspec命令自动计算数据规范并存储在数据规范文件中。

# 创建数据规范
./infer_dataspec --dataset=csv:adult_train.csv --output=dataspec.pbtxt

在训练模型之前查看数据规范是检测数据集中问题(如缺失值或错误数据类型)的好方法。

# 显示数据规范
./show_dataspec --dataspec=dataspec.pbtxt

结果如下:

记录数: 22792
列数: 15

按类型划分的列数:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

列:

CATEGORICAL: 9 (60%)
    3: "education" CATEGORICAL 有字典 词汇量:17 零个OOD项 最频繁项:"HS-grad" 7340 (32.2043%)
    14: "income" CATEGORICAL 有字典 词汇量:3 零个OOD项 最频繁项:"<=50K" 17308 (75.9389%)
    5: "marital_status" CATEGORICAL 有字典 词汇量:8 零个OOD项 最频繁项:"Married-civ-spouse" 10431 (45.7661%)
    13: "native_country" CATEGORICAL 缺失值数:407 (1.78571%) 有字典 词汇量:41 OOD项数:1 (0.00446728%) 最频繁项:"United-States" 20436 (91.2933%)
    6: "occupation" CATEGORICAL 缺失值数:1260 (5.52826%) 有字典 词汇量:14 OOD项数:1 (0.00464425%) 最频繁项:"Prof-specialty" 2870 (13.329%)
    8: "race" CATEGORICAL 有字典 词汇量:6 零个OOD项 最频繁项:"White" 19467 (85.4115%)
    7: "relationship" CATEGORICAL 有字典 词汇量:7 零个OOD项 最频繁项:"Husband" 9191 (40.3256%)
    9: "sex" CATEGORICAL 有字典 词汇量:3 零个OOD项 最频繁项:"Male" 15165 (66.5365%)
    1: "workclass" CATEGORICAL 缺失值数:1257 (5.51509%) 有字典 词汇量:8 OOD项数:1 (0.0046436%) 最频繁项:"Private" 15879 (73.7358%)

NUMERICAL: 6 (40%)
    0: "age" NUMERICAL 均值:38.6153 最小值:17 最大值:90 标准差:13.661
    10: "capital_gain" NUMERICAL 均值:1081.9 最小值:0 最大值:99999 标准差:7509.48
    11: "capital_loss" NUMERICAL 均值:87.2806 最小值:0 最大值:4356 标准差:403.01
    4: "education_num" NUMERICAL 均值:10.0927 最小值:1 最大值:16 标准差:2.56427
    2: "fnlwgt" NUMERICAL 均值:189879 最小值:12285 最大值:1.4847e+06 标准差:106423
    12: "hours_per_week" NUMERICAL 均值:40.3955 最小值:1 最大值:99 标准差:12.249

术语:
    nas: 不可用(即缺失)值的数量。
    ood: 字典外。
    手动定义: 由用户手动定义类型的属性,即类型不是自动推断的。
    分词: 属性值通过分词获得。
    有字典: 属性附加有字符串字典,例如存储为字符串的分类属性。
    词汇量: 唯一值的数量。
这个示例数据集包含22,792个样本和15列。其中有9个分类列和6个数值列。列的语义指的是它所包含的数据类型。

例如,education列是一个分类列,具有17个唯一可能的值。最频繁的值是HS-grad(占所有值的32%)。

(可选) 使用指南创建数据规范

在示例中,列的语义被正确检测。然而,当值的表示形式不明确时,情况可能并非如此。例如,枚举值(即以整数表示的分类值)无法在.csv文件中自动检测其语义。

在这种情况下,我们可以使用额外的标志重新运行infer_dataspec命令,以指示错误检测列的真实语义。例如,要将age强制检测为数值列,我们可以运行:

# 强制将'age'检测为数值列。
cat <<EOF > guide.pbtxt
column_guides {
    column_name_pattern: "^age$"
    type: NUMERICAL
}
EOF

./infer_dataspec --dataset=csv:adult_train.csv --guide=guide.pbtxt --output=dataspec.pbtxt

训练模型

模型通过train命令进行训练。标签、特征、超参数和其他训练设置在训练配置文件中指定。

# 创建训练配置文件
cat <<EOF > train_config.pbtxt
task: CLASSIFICATION
label: "income"
learner: "GRADIENT_BOOSTED_TREES"
# 更改学习器的特定超参数。
[yggdrasil_decision_forests.model.gradient_boosted_trees.proto.gradient_boosted_trees_config] {
  num_trees: 500
}
EOF

# 训练模型
./train \
  --dataset=csv:adult_train.csv \
  --dataspec=dataspec.pbtxt \
  --config=train_config.pbtxt \
  --output=model

结果:

[INFO train.cc:96] 开始训练模型。
[INFO abstract_learner.cc:119] 未指定输入特征。使用所有可用输入特征作为输入信号。
[INFO abstract_learner.cc:133] 标签 "income" 已从输入特征集中移除。
[INFO vertical_dataset_io.cc:74] 已扫描100个样本。
[INFO vertical_dataset_io.cc:80] 已读取22792个样本。内存:使用量:1MB,分配量:1MB。跳过了0(0%)个样本。
[INFO abstract_learner.cc:119] 未指定输入特征。使用所有可用输入特征作为输入信号。
[INFO abstract_learner.cc:133] 标签 "income" 已从输入特征集中移除。
[INFO gradient_boosted_trees.cc:405] 默认损失设置为 BINOMIAL_LOG_LIKELIHOOD
[INFO gradient_boosted_trees.cc:1008] 正在训练包含22792个样本和14个特征的梯度提升树。
[INFO gradient_boosted_trees.cc:1051] 20533个样本用于训练,2259个样本用于验证
[INFO gradient_boosted_trees.cc:1434]   num-trees:1 train-loss:1.015975 train-accuracy:0.761895 valid-loss:1.071430 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:1436]   num-trees:2 train-loss:0.955303 train-accuracy:0.761895 valid-loss:1.007908 valid-accuracy:0.736609
[INFO gradient_boosted_trees.cc:2871] 由于验证损失不再减少,提前停止训练。最佳验证损失:0.579583
[INFO gradient_boosted_trees.cc:230] 将模型截断为136棵树,即136次迭代。
[INFO gradient_boosted_trees.cc:264] 最终模型 num-trees:136 valid-loss:0.579583 valid-accuracy:0.870297

几点说明:

  • 由于未指定输入特征,除标签外的所有列都被用作输入特征。

  • DFs 原生支持数值、分类和分类集特征,以及缺失值。数值特征不需要归一化,分类字符串值也不需要字典编码。

  • 除了 num_trees 超参数外,未指定其他训练超参数。所有超参数的默认值都设置为在大多数情况下提供合理的结果。我们将在后面讨论替代的默认值(称为超参数模板)和超参数的自动调优。所有超参数及其默认值的列表可在超参数页面中找到。

  • 未为训练提供验证数据集。并非所有学习器都需要验证数据集。然而,本示例中使用的 GRADIENT_BOOSTED_TREES 学习器如果启用了早停(默认情况下启用),则需要验证数据集。在这种情况下,10% 的训练数据集用于验证。此比率可以通过 validation_ratio 参数进行更改。或者,可以使用 --valid_dataset 标志提供验证数据集。最终模型包含136棵树,验证准确率约为0.8702。

显示模型信息

使用 show_model 命令显示模型的详细信息。

# 显示模型信息。
./show_model --model=model

结果示例:

类型: "GRADIENT_BOOSTED_TREES"
任务: 分类
标签: "income"

输入特征 (14):
    age
    workclass
    fnlwgt
    education
    education_num
    marital_status
    occupation
    relationship
    race
    sex
    capital_gain
    capital_loss
    hours_per_week
    native_country

无权重

变量重要性: MEAN_MIN_DEPTH:
    1.         "income"  4.868164 ################
    2.            "sex"  4.625136 #############
    3.           "race"  4.590606 #############
    ...
   13.     "occupation"  3.640103 ####
   14. "marital_status"  3.626898 ###
   15.            "age"  3.219872

变量重要性: NUM_AS_ROOT:
    1.            "age" 28.000000 ################
    2. "marital_status" 22.000000 ############
    3.   "capital_gain" 19.000000 ##########
    ...
   11.  "education_num"  3.000000
   12.     "occupation"  2.000000
   13. "native_country"  2.000000

变量重要性: NUM_NODES:
    1.     "occupation" 516.000000 ################
    2.            "age" 431.000000 #############
    3.      "education" 424.000000 ############
    ...
   12.  "education_num" 73.000000 #
   13.            "sex" 39.000000
   14.           "race" 26.000000

变量重要性: SUM_SCORE:
    1.   "relationship" 3103.387636 ################
    2.   "capital_gain" 2041.557944 ##########
    3.      "education" 1090.544247 #####
    ...
   12.      "workclass" 176.876787
   13.            "sex" 49.287215
   14.           "race" 13.923084

损失: BINOMIAL_LOG_LIKELIHOOD
验证损失值: 0.579583
每迭代树的数量: 1
节点格式: BLOB_SEQUENCE
树的数量: 136
总节点数: 7384

每棵树的节点数:
计数: 136 平均值: 54.2941 标准差: 5.7779
最小值: 33 最大值: 63 忽略: 0
----------------------------------------------
[ 33, 34)  2   1.47%   1.47% #
...
[ 60, 62) 16  11.76%  96.32% ########
[ 62, 63]  5   3.68% 100.00% ##

叶节点的深度:
计数: 3760 平均值: 4.87739 标准差: 0.412078
最小值: 2 最大值: 5 忽略: 0
----------------------------------------------
[ 2, 3)   14   0.37%   0.37%
[ 3, 4)   75   1.99%   2.37%
[ 4, 5)  269   7.15%   9.52% #
[ 5, 5] 3402  90.48% 100.00% ##########

叶节点的训练观测数:
计数: 3760 平均值: 742.683 标准差: 2419.64
最小值: 5 最大值: 19713 忽略: 0
----------------------------------------------
[     5,   990) 3270  86.97%  86.97% ##########
[   990,  1975)  163   4.34%  91.30%
...
[ 17743, 18728)   10   0.27%  99.55%
[ 18728, 19713]   17   0.45% 100.00%

节点中的属性:
    516 : occupation [CATEGORICAL]
    431 : age [NUMERICAL]
    424 : education [CATEGORICAL]
    420 : fnlwgt [NUMERICAL]
    297 : capital_gain [NUMERICAL]
    291 : hours_per_week [NUMERICAL]
    266 : capital_loss [NUMERICAL]
    245 : native_country [CATEGORICAL]
    224 : relationship [CATEGORICAL]
    206 : workclass [CATEGORICAL]
    166 : marital_status [CATEGORICAL]
    73 : education_num [NUMERICAL]
    39 : sex [CATEGORICAL]
    26 : race [CATEGORICAL]

深度 <= 0 的节点中的属性:
    28 : age [NUMERICAL]
    22 : marital_status [CATEGORICAL]
    19 : capital_gain [NUMERICAL]
    12 : capital_loss [NUMERICAL]
    11 : hours_per_week [NUMERICAL]
    11 : fnlwgt [NUMERICAL]
    8 : relationship [CATEGORICAL]
    8 : education [CATEGORICAL]
    6 : race [CATEGORICAL]
    4 : sex [CATEGORICAL]
    3 : education_num [NUMERICAL]
    2 : native_country [CATEGORICAL]
    2 : occupation [CATEGORICAL]

...

节点中的条件类型:
    1844 : ContainsBitmapCondition
    1778 : HigherCondition
    2 : ContainsCondition
深度 <= 0 的节点中的条件类型:
    84 : HigherCondition
    52 : ContainsBitmapCondition
深度 <= 1 的节点中的条件类型:
    243 : HigherCondition
    165 : ContainsBitmapCondition
...
可以使用 --full_definition 标志打印模型树的结构。

评估模型

使用 evaluate 命令计算并打印评估结果,结果可以以文本格式(--format=text,默认)或带有图表的 HTML 格式(--format=html)输出。

# 评估模型并在控制台打印结果
./evaluate --dataset=csv:adult_test.csv --model=model

结果

评估:
预测数量(无权重): 9769
预测数量(有权重): 9769
任务: 分类
标签: 收入

准确率: 0.874399  CI95[W][0.86875 0.879882]
对数损失: 0.27768
错误率: 0.125601

默认准确率: 0.758727
默认对数损失: 0.552543
默认错误率: 0.241273

混淆矩阵:
truth\prediction
       <OOD>  <=50K  >50K
<OOD>      0      0     0
<=50K      0   6971   441
 >50K      0    786  1571
总计: 9769

一对一分类:
  "<=50K" 对其他类别
    auc: 0.929207  CI95[H][0.924358 0.934056] CI95[B][0.924076 0.934662]
    p/r-auc: 0.975657  CI95[L][0.971891 0.97893] CI95[B][0.973397 0.977947]
    ap: 0.975656   CI95[B][0.973393 0.977944]

  ">50K" 对其他类别
    auc: 0.929207  CI95[H][0.921866 0.936549] CI95[B][0.923642 0.934566]
    p/r-auc: 0.830708  CI95[L][0.815025 0.845313] CI95[B][0.817588 0.843956]
    ap: 0.830674   CI95[B][0.817513 0.843892]

观察:

  • 测试数据集包含 9769 个样本。
  • 测试准确率为 0.874399,95% 置信区间为 [0.86875; 0.879882]。
  • 测试 AUC 为 0.929207,95% 置信区间为 [0.924358 0.934056](使用闭式计算时)和 [0.973397 0.977947](使用自助法计算时)。
  • PR-AUC 和 AP 指标也可用。

以下命令评估模型并将评估报告导出到 HTML 文件。

# 评估模型并将结果打印到 Html 文件中
./evaluate --dataset=csv:adult_test.csv --model=model --format=html > evaluation.html

成人数据集上的评估图

生成预测

使用 predict 命令计算预测并导出到文件。

# 将模型的预测导出到 csv 文件
./predict --dataset=csv:adult_test.csv --model=model --output=csv:predictions.csv

# 显示前 3 个样本的预测
head -n 4 predictions.csv

结果:

<=50K,>50K
0.978384,0.0216162
0.641894,0.358106
0.180569,0.819431

基准模型速度

在时间敏感的应用中,模型的推理速度至关重要。benchmark_inference 命令测量模型的平均推理时间。

YDF 有多种算法来计算模型的预测。这些算法在速度和覆盖范围上有所不同。生成预测时,YDF 自动使用兼容的最快算法。

benchmark_inference 显示所有兼容算法的速度。

推理算法是单线程的,这意味着它们一次只能处理一个数据点。用户需要使用多线程来并行化推理。

# 基准测试模型的推理速度
./benchmark_inference --dataset=csv:adult_test.csv --model=model

结果:

批量大小 : 100  运行次数 : 20
每样本时间(微秒)  每批次时间(微秒)  方法
----------------------------------------
            0.89              89  GradientBoostedTreesQuickScorerExtended [虚拟接口]
          5.8475          584.75  GradientBoostedTreesGeneric [虚拟接口]
          12.485          1248.5  通用慢速引擎
----------------------------------------

我们看到模型平均每样本运行时间为 0.89 微秒。