🧭 Getting started
Install YDF
pip install ydf -U
Import libraries
import ydf  # Yggdrasil Decision Forests
import pandas as pd  # We use Pandas to load small datasets.
Download and load the dataset
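We use the Adult dataset, a binary classification dataset. The goal is to predict the value of the income column, which is either <=50K or >50K, from the other numerical and categorical columns. The dataset contains missing values.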
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the datasets as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
# Print the first 5 training examples
train_ds.head(5)
 | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
Train a model
Let's train a Gradient Boosted Trees model using the default values of all hyperparameters.
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
Train model on 22792 examples
Model trained in 0:00:03.698584
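Remarks
- YDF differentiates between learning algorithms (called learners, such as GradientBoostedTreesLearner) and models. Later, in more advanced examples, you will see why we do this :).
- The only required argument of a learner is label. The other arguments have good default values.
- We did not specify the input features, so all the columns are used as input features. The type of each feature is detected automatically (e.g., numerical, categorical, boolean, text, possibly with missing values) and handled accordingly.
- By default, learners train classification models. Other tasks (e.g., regression, ranking, uplifting) are configured with the task argument, for example task=ydf.Task.REGRESSION.
- The training logs can be shown during training with the verbose=2 argument, or after training with model.describe(). They are useful for debugging and for understanding the training process.
- No validation dataset was specified. In this case, learners such as GradientBoostedTreesLearner set aside part of the training dataset for validation. Other learners, such as RandomForestLearner, do not need a validation dataset and use all the data for training.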
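The describe() output further down is consistent with this behavior. The root of the first tree sees n:20533 examples, and the validation report totals 2259 examples. Together they make up the 22792 training examples, so roughly 10% was held out. The sketch below shows what such a random hold-out split looks like with the standard library. It is only an illustration. The helper split_train_valid and its 0.1 ratio are assumptions chosen to match those numbers, not YDF's implementation.

```python
import random


def split_train_valid(num_examples, valid_ratio=0.1, seed=1234):
    """Randomly assigns example indices to a training or a validation subset.

    Hypothetical helper illustrating a hold-out split; it is not YDF's code.
    Each example goes to the validation subset with probability valid_ratio.
    """
    rng = random.Random(seed)  # seeded so that the split is reproducible
    train_idx, valid_idx = [], []
    for i in range(num_examples):
        if rng.random() < valid_ratio:
            valid_idx.append(i)
        else:
            train_idx.append(i)
    return train_idx, valid_idx


train_idx, valid_idx = split_train_valid(22792)
print(len(train_idx), len(valid_idx))  # roughly a 90% / 10% split
```

Holding out examples costs some training data. In exchange, the learner gets an unbiased signal for early stopping without a separate dataset.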
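Inspect the model
With model.describe(), we can inspect:
- Model: the task, input features, and size of the model.
- Dataspec: the type and statistics of every input feature.
- Training: the training and validation losses and metrics.
- Tuning (only when hyperparameter tuning is enabled): the tuning logs.
- Variable importances: the features that matter most to the model.
- Structure: the trees in the model.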
model.describe()
Task : CLASSIFICATION
Label : income
Features (14) : age workclass fnlwgt education education_num marital_status occupation relationship race sex capital_gain capital_loss hours_per_week native_country
Weights : None
Trained with tuner : No
Model size : 2174 kB
Number of records: 22792
Number of columns: 15

Number of columns by type:
    CATEGORICAL: 9 (60%)
    NUMERICAL: 6 (40%)

Columns:

CATEGORICAL: 9 (60%)
    0: "income" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"<=50K" 17308 (75.9389%)
    2: "workclass" CATEGORICAL num-nas:1257 (5.51509%) has-dict vocab-size:8 num-oods:3 (0.0139308%) most-frequent:"Private" 15879 (73.7358%)
    4: "education" CATEGORICAL has-dict vocab-size:17 zero-ood-items most-frequent:"HS-grad" 7340 (32.2043%)
    6: "marital_status" CATEGORICAL has-dict vocab-size:8 zero-ood-items most-frequent:"Married-civ-spouse" 10431 (45.7661%)
    7: "occupation" CATEGORICAL num-nas:1260 (5.52826%) has-dict vocab-size:14 num-oods:4 (0.018577%) most-frequent:"Prof-specialty" 2870 (13.329%)
    8: "relationship" CATEGORICAL has-dict vocab-size:7 zero-ood-items most-frequent:"Husband" 9191 (40.3256%)
    9: "race" CATEGORICAL has-dict vocab-size:6 zero-ood-items most-frequent:"White" 19467 (85.4115%)
    10: "sex" CATEGORICAL has-dict vocab-size:3 zero-ood-items most-frequent:"Male" 15165 (66.5365%)
    14: "native_country" CATEGORICAL num-nas:407 (1.78571%) has-dict vocab-size:41 num-oods:1 (0.00446728%) most-frequent:"United-States" 20436 (91.2933%)

NUMERICAL: 6 (40%)
    1: "age" NUMERICAL mean:38.6153 min:17 max:90 sd:13.661
    3: "fnlwgt" NUMERICAL mean:189879 min:12285 max:1.4847e+06 sd:106423
    5: "education_num" NUMERICAL mean:10.0927 min:1 max:16 sd:2.56427
    11: "capital_gain" NUMERICAL mean:1081.9 min:0 max:99999 sd:7509.48
    12: "capital_loss" NUMERICAL mean:87.2806 min:0 max:4356 sd:403.01
    13: "hours_per_week" NUMERICAL mean:40.3955 min:1 max:99 sd:12.249

Terminology:
    nas: Number of non-available (i.e. missing) values.
    ood: Out of dictionary.
    manually-defined: Attribute whose type is manually defined by the user, i.e., the type was not automatically inferred.
    tokenized: The attribute value is obtained through tokenization.
    has-dict: The attribute is attached to a string dictionary e.g. a categorical attribute stored as a string.
    vocab-size: Number of unique values.
The following evaluation is computed on the validation or out-of-bag dataset.
Task: CLASSIFICATION
Label: income
Loss (BINOMIAL_LOG_LIKELIHOOD): 0.576162

Accuracy: 0.868526  CI95[W][0 1]
ErrorRate: : 0.131474

Confusion Table:
truth\prediction
       <=50K  >50K
<=50K   1557   107
 >50K    190   405
Total: 2259
Variable importances measure the importance of an input feature for a model.
1. "age" 0.226642 ################
2. "occupation" 0.219727 #############
3. "capital_gain" 0.214876 ############
4. "education" 0.213746 ###########
5. "marital_status" 0.212739 ###########
6. "relationship" 0.206040 #########
7. "fnlwgt" 0.203843 ########
8. "hours_per_week" 0.203735 ########
9. "capital_loss" 0.196549 ######
10. "native_country" 0.190548 ####
11. "workclass" 0.187795 ###
12. "education_num" 0.184215 ##
13. "race" 0.180495
14. "sex" 0.177647

1. "age" 26.000000 ################
2. "capital_gain" 26.000000 ################
3. "marital_status" 20.000000 ############
4. "relationship" 17.000000 ##########
5. "capital_loss" 14.000000 ########
6. "hours_per_week" 14.000000 ########
7. "education" 12.000000 #######
8. "fnlwgt" 10.000000 #####
9. "race" 9.000000 #####
10. "education_num" 7.000000 ###
11. "sex" 4.000000 #
12. "occupation" 2.000000
13. "workclass" 1.000000
14. "native_country" 1.000000

1. "occupation" 724.000000 ################
2. "fnlwgt" 513.000000 ###########
3. "age" 483.000000 ##########
4. "education" 464.000000 ##########
5. "hours_per_week" 339.000000 #######
6. "capital_gain" 326.000000 ######
7. "native_country" 306.000000 ######
8. "capital_loss" 297.000000 ######
9. "relationship" 262.000000 #####
10. "workclass" 244.000000 #####
11. "marital_status" 210.000000 ####
12. "education_num" 82.000000 #
13. "sex" 42.000000
14. "race" 21.000000

1. "relationship" 3014.690076 ################
2. "capital_gain" 2065.521668 ##########
3. "education" 1144.490954 ######
4. "marital_status" 1111.389695 #####
5. "occupation" 1094.619502 #####
6. "education_num" 796.666823 ####
7. "capital_loss" 584.055066 ###
8. "age" 582.288569 ###
9. "hours_per_week" 366.856509 #
10. "native_country" 263.872689 #
11. "fnlwgt" 216.537764 #
12. "workclass" 196.085503 #
13. "sex" 47.217730
14. "race" 5.428727
Those variable importances are computed during training. More, and possibly more informative, variable importances are available when analyzing a model on a test dataset.
Only printing the first tree.
Tree #0:
"relationship" is in [BITMAP] {<OOD>, Husband, Wife} [s:0.036623 n:20533 np:9213 miss:1] ; pred:-4.15883e-09
    ├─(pos)─ "education_num">=12.5 [s:0.0343752 n:9213 np:2773 miss:0] ; pred:0.116933
    |   ├─(pos)─ "capital_gain">=5095.5 [s:0.0125728 n:2773 np:434 miss:0] ; pred:0.272683
    |   |   ├─(pos)─ "occupation" is in [BITMAP] {<OOD>, Prof-specialty, Exec-managerial, Craft-repair, Adm-clerical, Sales, Other-service, Machine-op-inspct, Transport-moving, Handlers-cleaners, ...[2 left]} [s:0.000434532 n:434 np:429 miss:1] ; pred:0.416173
    |   |   |   ├─(pos)─ "age">=79.5 [s:0.000449964 n:429 np:5 miss:0] ; pred:0.417414
    |   |   |   |   ├─(pos)─ pred:0.309737
    |   |   |   |   └─(neg)─ pred:0.418684
    |   |   |   └─(neg)─ pred:0.309737
    |   |   └─(neg)─ "capital_loss">=1782.5 [s:0.0101181 n:2339 np:249 miss:0] ; pred:0.246058
    |   |       ├─(pos)─ "capital_loss">=1989.5 [s:0.00201289 n:249 np:39 miss:0] ; pred:0.406701
    |   |       |   ├─(pos)─ pred:0.349312
    |   |       |   └─(neg)─ pred:0.417359
    |   |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Sales, Tech-support, Protective-serv} [s:0.0097175 n:2090 np:1688 miss:0] ; pred:0.226919
    |   |           ├─(pos)─ pred:0.253437
    |   |           └─(neg)─ pred:0.11557
    |   └─(neg)─ "capital_gain">=5095.5 [s:0.0205419 n:6440 np:303 miss:0] ; pred:0.0498685
    |       ├─(pos)─ "age">=60.5 [s:0.00421502 n:303 np:43 miss:0] ; pred:0.40543
    |       |   ├─(pos)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Machine-op-inspct, Transport-moving, Handlers-cleaners} [s:0.0296244 n:43 np:25 miss:0] ; pred:0.317428
    |       |   |   ├─(pos)─ pred:0.397934
    |       |   |   └─(neg)─ pred:0.205614
    |       |   └─(neg)─ "fnlwgt">=36212.5 [s:1.36643e-16 n:260 np:250 miss:1] ; pred:0.419984
    |       |       ├─(pos)─ pred:0.419984
    |       |       └─(neg)─ pred:0.419984
    |       └─(neg)─ "occupation" is in [BITMAP] {Prof-specialty, Exec-managerial, Adm-clerical, Sales, Tech-support, Protective-serv} [s:0.0100346 n:6137 np:2334 miss:0] ; pred:0.0323136
    |           ├─(pos)─ "age">=33.5 [s:0.00939348 n:2334 np:1769 miss:1] ; pred:0.102799
    |           |   ├─(pos)─ pred:0.132992
    |           |   └─(neg)─ pred:0.00826457
    |           └─(neg)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Bachelors, Masters, Assoc-voc, Assoc-acdm, Prof-school, Doctorate} [s:0.00478423 n:3803 np:2941 miss:1] ; pred:-0.0109452
    |               ├─(pos)─ pred:0.00969668
    |               └─(neg)─ pred:-0.0813718
    └─(neg)─ "capital_gain">=7073.5 [s:0.0143125 n:11320 np:199 miss:0] ; pred:-0.0951681
        ├─(pos)─ "age">=21.5 [s:0.00807667 n:199 np:194 miss:1] ; pred:0.397823
        |   ├─(pos)─ "capital_gain">=7565.5 [s:0.00761118 n:194 np:184 miss:0] ; pred:0.405777
        |   |   ├─(pos)─ "capital_gain">=30961.5 [s:0.000242202 n:184 np:20 miss:0] ; pred:0.416988
        |   |   |   ├─(pos)─ pred:0.392422
        |   |   |   └─(neg)─ pred:0.419984
        |   |   └─(neg)─ "education" is in [BITMAP] {Bachelors, Masters, Assoc-voc, Prof-school} [s:0.16 n:10 np:5 miss:0] ; pred:0.19949
        |   |       ├─(pos)─ pred:0.419984
        |   |       └─(neg)─ pred:-0.0210046
        |   └─(neg)─ pred:0.0892425
        └─(neg)─ "education" is in [BITMAP] {<OOD>, Bachelors, Masters, Prof-school, Doctorate} [s:0.00229611 n:11121 np:2199 miss:1] ; pred:-0.10399
            ├─(pos)─ "age">=31.5 [s:0.00725859 n:2199 np:1263 miss:1] ; pred:-0.0507848
            |   ├─(pos)─ "education" is in [BITMAP] {<OOD>, HS-grad, Some-college, Assoc-voc, 11th, Assoc-acdm, 10th, 7th-8th, Prof-school, 9th, ...[5 left]} [s:0.0110157 n:1263 np:125 miss:1] ; pred:-0.0103552
            |   |   ├─(pos)─ pred:0.16421
            |   |   └─(neg)─ pred:-0.0295298
            |   └─(neg)─ "capital_loss">=1977 [s:0.00164232 n:936 np:5 miss:0] ; pred:-0.105339
            |       ├─(pos)─ pred:0.19949
            |       └─(neg)─ pred:-0.106976
            └─(neg)─ "capital_loss">=2218.5 [s:0.000534265 n:8922 np:41 miss:0] ; pred:-0.117103
                ├─(pos)─ "fnlwgt">=125450 [s:0.0755454 n:41 np:28 miss:1] ; pred:0.0704198
                |   ├─(pos)─ pred:-0.0328167
                |   └─(neg)─ pred:0.292776
                └─(neg)─ "hours_per_week">=40.5 [s:0.000447024 n:8881 np:1559 miss:0] ; pred:-0.117969
                    ├─(pos)─ pred:-0.0927111
                    └─(neg)─ pred:-0.123347
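The pred: values in the tree above are contributions to a score, not probabilities. The root's pred is about 0, and many leaves are negative. This sketch assumes the standard gradient boosting formulation for the BINOMIAL_LOG_LIKELIHOOD loss reported in the training section. Under that assumption, the model adds one leaf value per tree to an initial score, and a sigmoid maps the total to a probability. The function gbt_binary_probability and the numbers in the usage line are hypothetical. YDF's actual inference code is in C++.

```python
import math


def gbt_binary_probability(initial_score, leaf_values):
    """Turns the leaf values reached by one example (one per tree) into a probability.

    Sketch of the usual gradient boosting formulation for a binomial
    log-likelihood loss: contributions are summed in logit space, then
    squashed with the logistic (sigmoid) function.
    """
    logit = initial_score + sum(leaf_values)
    # Numerically stable sigmoid: avoids overflow in exp() for large |logit|.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


# Hypothetical initial score and leaf values from three trees.
print(gbt_binary_probability(0.0, [0.42, -0.12, 0.05]))
```

Because the aggregation happens in logit space, a single tree can only nudge the prediction. Many trees are needed to move a probability far from the initial score.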
Make predictions
model.predict(ds) applies the model and returns the predictions as a NumPy array.
model.predict(test_ds)
array([0.01860435, 0.36130956, 0.83858865, ..., 0.03087652, 0.08280362, 0.00970956], dtype=float32)
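For this binary classification model, each value of the array is a probability. Below is a minimal sketch of turning such probabilities into class names with a 0.5 threshold. It assumes the values are probabilities of the >50K class. The helper probabilities_to_labels is hypothetical and not part of YDF.

```python
def probabilities_to_labels(probabilities, positive_label=">50K",
                            negative_label="<=50K", threshold=0.5):
    """Maps positive-class probabilities to class names with a fixed threshold."""
    return [positive_label if p >= threshold else negative_label
            for p in probabilities]


# The first three test predictions printed above.
print(probabilities_to_labels([0.01860435, 0.36130956, 0.83858865]))
# ['<=50K', '<=50K', '>50K']
```

A lower threshold labels more examples as >50K, which catches more true positives at the cost of more false positives.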
Methods that consume datasets, such as train and predict, support many dataset formats: Pandas DataFrames, dictionaries of lists or NumPy arrays, TensorFlow datasets, and even file paths!
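The dictionary format is column-oriented: each column name maps to a list holding one value per example. If your data arrives as one record per example, a small helper can transpose it first. Here is a minimal standard-library sketch. rows_to_columns is a hypothetical helper, not part of YDF, and the second record's values are made up.

```python
def rows_to_columns(rows):
    """Converts a list of row dictionaries into a dictionary of column lists."""
    if not rows:
        return {}
    names = list(rows[0])
    for row in rows:
        # Every row must have the same columns, otherwise the lists would
        # end up with different lengths.
        if set(row) != set(names):
            raise ValueError(f"Row has columns {sorted(row)}, expected {sorted(names)}")
    return {name: [row[name] for row in rows] for name in names}


rows = [
    {"age": 39, "workclass": "State-gov", "hours_per_week": 40},
    {"age": 50, "workclass": "Private", "hours_per_week": 13},  # hypothetical
]
print(rows_to_columns(rows))
# {'age': [39, 50], 'workclass': ['State-gov', 'Private'], 'hours_per_week': [40, 13]}
```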
# Predict from a dictionary
model.predict({
'age': [39],
'workclass': ['State-gov'],
'fnlwgt': [77516],
'education': ['Bachelors'],
'education_num': [13],
'marital_status': ['Never-married'],
'occupation': ['Adm-clerical'],
'relationship': ['Not-in-family'],
'race': ['White'],
'sex': ['Male'],
'capital_gain': [2174],
'capital_loss': [0],
'hours_per_week': [40],
'native_country': ['United-States'],
'income': ['<=50K'],
})
array([0.01860435], dtype=float32)
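Evaluate the model
While the validation dataset above gives an indication of the model's quality, we also want to evaluate the model on the test dataset.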
evaluation = model.evaluate(test_ds)
# Query individual evaluation metrics
print(f"test accuracy: {evaluation.accuracy}")
# Show the full evaluation report
print("Full evaluation report:")
evaluation
test accuracy: 0.8738867847271983
Full evaluation report:
Label \ Pred | <=50K | >50K |
---|---|---|
<=50K | 6962 | 782 |
>50K | 450 | 1575 |
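Accuracy is the number of correctly classified examples (the diagonal of the confusion table) divided by the total. It therefore does not depend on whether the rows hold labels or predictions. The sketch below recomputes the reported validation and test accuracies from the two confusion tables. It then compares them with a majority-class baseline taken from the training dataspec, where 17308 of 22792 examples are "<=50K". The helper accuracy_from_confusion is illustrative.

```python
def accuracy_from_confusion(matrix):
    """Accuracy = correctly classified examples (the diagonal) / all examples."""
    total = sum(sum(row) for row in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total


validation = [[1557, 107], [190, 405]]  # from the describe() report
test = [[6962, 782], [450, 1575]]       # from the evaluation report above
majority_baseline = 17308 / 22792       # always predict "<=50K"

print(f"validation accuracy: {accuracy_from_confusion(validation):.6f}")  # 0.868526
print(f"test accuracy: {accuracy_from_confusion(test):.6f}")  # 0.873887
print(f"majority-class baseline: {majority_baseline:.6f}")  # 0.759389
```

The model is about 11 points above the trivial baseline. That gap, rather than the raw accuracy, shows how much the features help.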
model.analyze(test_ds, sampling=0.1)
Variable importances measure the importance of an input feature for a model.
1. "capital_gain" 0.052513 ################
2. "occupation" 0.020882 ######
3. "age" 0.015559 ####
4. "relationship" 0.015150 ####
5. "marital_status" 0.014331 ####
6. "capital_loss" 0.014331 ####
7. "education" 0.009110 ##
8. "hours_per_week" 0.006551 #
9. "education_num" 0.005323 #
10. "workclass" 0.003378
11. "race" 0.001024
12. "sex" 0.000921
13. "fnlwgt" 0.000614
14. "native_country" 0.000614

1. "capital_gain" 0.248326 ################
2. "age" 0.051386 ###
3. "marital_status" 0.046224 ##
4. "capital_loss" 0.044403 ##
5. "occupation" 0.037985 ##
6. "relationship" 0.037500 ##
7. "education" 0.021677 #
8. "hours_per_week" 0.015487
9. "education_num" 0.008588
10. "workclass" 0.003808
11. "sex" 0.003478
12. "fnlwgt" 0.002788
13. "native_country" 0.001978
14. "race" 0.001111

1. "capital_gain" 0.061589 ################
2. "age" 0.033311 ########
3. "marital_status" 0.029546 #######
4. "relationship" 0.020694 #####
5. "occupation" 0.019686 #####
6. "capital_loss" 0.014316 ###
7. "education" 0.012061 ##
8. "hours_per_week" 0.009984 ##
9. "education_num" 0.004140
10. "sex" 0.001985
11. "workclass" 0.001577
12. "native_country" 0.001397
13. "fnlwgt" 0.000936
14. "race" 0.000637

1. "capital_gain" 0.248064 ################
2. "age" 0.051338 ###
3. "marital_status" 0.045982 ##
4. "capital_loss" 0.044387 ##
5. "occupation" 0.037982 ##
6. "relationship" 0.037494 ##
7. "education" 0.021676 #
8. "hours_per_week" 0.015486
9. "education_num" 0.008585
10. "workclass" 0.003812
11. "sex" 0.003477
12. "fnlwgt" 0.002791
13. "native_country" 0.001981
14. "race" 0.001112
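In the variable importance listings above, each "#" bar appears to be the value rescaled linearly between the smallest and largest importance of its listing. On this reading, a bar is at most 16 characters long, with the length rounded down. This is our inference from the printed numbers, not documented YDF behavior. The sketch below reproduces the second training-time listing under that assumption. The helper importance_bars is hypothetical.

```python
import math


def importance_bars(importances, width=16):
    """Formats importances like the listings above: rank, name, value, '#' bar.

    The bar length is the value rescaled linearly between the smallest
    (0 characters) and the largest (width characters) importance, rounded down.
    """
    lo = min(importances.values())
    hi = max(importances.values())
    span = hi - lo
    # sorted() is stable, so ties keep their original order.
    ranked = sorted(importances.items(), key=lambda kv: kv[1], reverse=True)
    lines = []
    for rank, (name, value) in enumerate(ranked, start=1):
        length = math.floor(width * (value - lo) / span) if span > 0 else 0
        lines.append(f'{rank}. "{name}" {value:.6f} {"#" * length}'.rstrip())
    return lines


# Values of the second training-time listing printed by model.describe().
listing = {
    "age": 26, "capital_gain": 26, "marital_status": 20, "relationship": 17,
    "capital_loss": 14, "hours_per_week": 14, "education": 12, "fnlwgt": 10,
    "race": 9, "education_num": 7, "sex": 4, "occupation": 2,
    "workclass": 1, "native_country": 1,
}
print("\n".join(importance_bars(listing)))
```

Because the scale is relative to each listing, bars from different listings are not comparable. Only the ranking within a listing is meaningful.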
Benchmark model speed
In applications where model speed matters, model.benchmark(ds) measures the inference speed of the model.
model.benchmark(test_ds)
Inference time per example and per cpu core: 0.891 us (microseconds)
Estimated over 345 runs over 3.004 seconds.
* Measured with the C++ serving API. Check model.to_cpp() for details.
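The benchmark reports a latency of 0.891 us per example and per CPU core. Its reciprocal gives a per-core throughput, which is often easier to compare with a traffic budget. The second part of the sketch checks the log against the test set size of 9769 examples, the total of the test confusion table. It assumes each benchmark run predicts every test example once. That is our inference, not something the log states.

```python
latency_us = 0.891    # inference time per example and per CPU core (from the log)
runs = 345            # number of benchmark runs (from the log)
test_examples = 9769  # 6962 + 782 + 450 + 1575, from the test confusion table

# Predictions per second that a single core can sustain.
per_core_throughput = 1e6 / latency_us
print(f"~{per_core_throughput:,.0f} predictions per second per core")

# Assumption: each run predicts every test example once.
estimated_seconds = runs * test_examples * latency_us / 1e6
print(f"estimated total: {estimated_seconds:.3f} s (log: 3.004 s)")
```

The estimate lands within a few milliseconds of the logged duration, which supports the assumption about what a run covers.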
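The benchmark measures the speed of the model with the C++ API. The Python API is slower because of the overhead of the Python interpreter. If you are not familiar with the C++ API, the model.to_cpp() method generates C++ code that runs the model, which you can use to measure the model's speed yourself.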
print(model.to_cpp())
// Automatically generated code running an Yggdrasil Decision Forests model in
// C++. This code was generated with "model.to_cpp()".
//
// Date of generation: 2023-12-19 15:29:09.343331
// YDF Version: 0.0.8
//
// How to use this code:
//
// 1. Copy this code in a new .h file.
// 2. If you use Bazel/Blaze, use the following dependencies:
//      //third_party/absl/status:statusor
//      //third_party/absl/strings
//      //external/ydf_cc/yggdrasil_decision_forests/api:serving
// 3. In your existing code, include the .h file and do:
//   // Load the model (to do only once).
//   namespace ydf = yggdrasil_decision_forests;
//   const auto model = ydf::exported_model_123::Load(<path to model>);
//   // Run the model
//   predictions = model.Predict();
// 4. By default, the "Predict" function takes no inputs and creates fake
//   examples. In practice, you want to add your input data as arguments to
//   "Predict" and call "examples->Set..." functions accordingly.
// 4. (Bonus)
//   Allocate one "examples" and "predictions" per thread and reuse them to
//   speed-up the inference.
//
#ifndef YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
#define YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model

#include <memory>
#include <vector>

#include "third_party/absl/status/statusor.h"
#include "third_party/absl/strings/string_view.h"
#include "external/ydf_cc/yggdrasil_decision_forests/api/serving.h"

namespace yggdrasil_decision_forests {
namespace exported_model_my_model {

struct ServingModel {
  std::vector<float> Predict() const;

  // Compiled model.
  std::unique_ptr<serving_api::FastEngine> engine;

  // Index of the input features of the model.
  //
  // Non-owning pointer. The data is owned by the engine.
  const serving_api::FeaturesDefinition* features;

  // Number of output predictions for each example.
  // Equal to 1 for regression, ranking and binary classification with compact
  // format. Equal to the number of classes for classification.
  int NumPredictionDimension() const {
    return engine->NumPredictionDimension();
  }

  // Indexes of the input features.
  serving_api::NumericalFeatureId feature_age;
  serving_api::CategoricalFeatureId feature_workclass;
  serving_api::NumericalFeatureId feature_fnlwgt;
  serving_api::CategoricalFeatureId feature_education;
  serving_api::NumericalFeatureId feature_education_num;
  serving_api::CategoricalFeatureId feature_marital_status;
  serving_api::CategoricalFeatureId feature_occupation;
  serving_api::CategoricalFeatureId feature_relationship;
  serving_api::CategoricalFeatureId feature_race;
  serving_api::CategoricalFeatureId feature_sex;
  serving_api::NumericalFeatureId feature_capital_gain;
  serving_api::NumericalFeatureId feature_capital_loss;
  serving_api::NumericalFeatureId feature_hours_per_week;
  serving_api::CategoricalFeatureId feature_native_country;
};

// TODO: Pass input feature values to "Predict".
inline std::vector<float> ServingModel::Predict() const {
  // Allocate memory for 2 examples. Alternatively, for speed-sensitive code,
  // an "examples" object can be allocated for each thread and reused. It is
  // okay to allocate more examples than needed.
  const int num_examples = 2;
  auto examples = engine->AllocateExamples(num_examples);

  // Set all the values to be missing. The values may then be overridden by the
  // "Set*" methods. If all the values are set with "Set*" methods,
  // "FillMissing" can be skipped.
  examples->FillMissing(*features);

  // Example #0
  examples->SetNumerical(/*example_idx=*/0, feature_age, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_workclass, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_fnlwgt, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_education, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_education_num, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_marital_status, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_occupation, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_relationship, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_race, "A", *features);
  examples->SetCategorical(/*example_idx=*/0, feature_sex, "A", *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_gain, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_capital_loss, 1.f, *features);
  examples->SetNumerical(/*example_idx=*/0, feature_hours_per_week, 1.f, *features);
  examples->SetCategorical(/*example_idx=*/0, feature_native_country, "A", *features);

  // Example #1
  examples->SetNumerical(/*example_idx=*/1, feature_age, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_workclass, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_fnlwgt, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_education, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_education_num, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_marital_status, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_occupation, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_relationship, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_race, "B", *features);
  examples->SetCategorical(/*example_idx=*/1, feature_sex, "B", *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_gain, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_capital_loss, 2.f, *features);
  examples->SetNumerical(/*example_idx=*/1, feature_hours_per_week, 2.f, *features);
  examples->SetCategorical(/*example_idx=*/1, feature_native_country, "B", *features);

  // Run the model on the two examples.
  //
  // For speed-sensitive code, reuse the same predictions.
  std::vector<float> predictions;
  engine->Predict(*examples, num_examples, &predictions);
  return predictions;
}

inline absl::StatusOr<ServingModel> Load(absl::string_view path) {
  ServingModel m;

  // Load the model
  ASSIGN_OR_RETURN(auto model, serving_api::LoadModel(path));

  // Compile the model into an inference engine.
  ASSIGN_OR_RETURN(m.engine, model->BuildFastEngine());

  // Index the input features of the model.
  m.features = &m.engine->features();

  // Index the input features.
  ASSIGN_OR_RETURN(m.feature_age, m.features->GetNumericalFeatureId("age"));
  ASSIGN_OR_RETURN(m.feature_workclass, m.features->GetCategoricalFeatureId("workclass"));
  ASSIGN_OR_RETURN(m.feature_fnlwgt, m.features->GetNumericalFeatureId("fnlwgt"));
  ASSIGN_OR_RETURN(m.feature_education, m.features->GetCategoricalFeatureId("education"));
  ASSIGN_OR_RETURN(m.feature_education_num, m.features->GetNumericalFeatureId("education_num"));
  ASSIGN_OR_RETURN(m.feature_marital_status, m.features->GetCategoricalFeatureId("marital_status"));
  ASSIGN_OR_RETURN(m.feature_occupation, m.features->GetCategoricalFeatureId("occupation"));
  ASSIGN_OR_RETURN(m.feature_relationship, m.features->GetCategoricalFeatureId("relationship"));
  ASSIGN_OR_RETURN(m.feature_race, m.features->GetCategoricalFeatureId("race"));
  ASSIGN_OR_RETURN(m.feature_sex, m.features->GetCategoricalFeatureId("sex"));
  ASSIGN_OR_RETURN(m.feature_capital_gain, m.features->GetNumericalFeatureId("capital_gain"));
  ASSIGN_OR_RETURN(m.feature_capital_loss, m.features->GetNumericalFeatureId("capital_loss"));
  ASSIGN_OR_RETURN(m.feature_hours_per_week, m.features->GetNumericalFeatureId("hours_per_week"));
  ASSIGN_OR_RETURN(m.feature_native_country, m.features->GetCategoricalFeatureId("native_country"));

  return m;
}

}  // namespace exported_model_my_model
}  // namespace yggdrasil_decision_forests

#endif  // YGGDRASIL_DECISION_FORESTS_GENERATED_MODEL_my_model
Save the model
Finally, we save the model for later use.
model.save("/tmp/my_model")
The model can then be loaded back with:
loaded_model = ydf.load_model("/tmp/my_model")
print(f"This is a {loaded_model.name()} model.")
This is a GRADIENT_BOOSTED_TREES model.
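Conclusion
That's it. You now know the basic features of YDF 😊.
To learn more about YDF, check out the other tutorials on ydf.readthedocs.io. For example, learn how to:
- Train ranking, regression, or uplifting models with the task argument.
- Measure distances and find the nearest neighbors between examples with model.distance.
- Impose monotonic constraints on features with the features argument.
- Run a model in a web page with JavaScript, using model.to_javascript().
- Convert a model to a TensorFlow SavedModel and run it in TensorFlow Serving, using model.to_tensorflow_saved_model().
- Use distributed training to train models on billions of training examples.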