Setup¶
# Install the dependencies
!pip install ydf -U -q
!pip install tensorflow -U -q
!pip install optax pandas numpy -U -q
!pip install jax[cpu] -U -q
# or
# !pip install jax[cuda12] -U -q
# See https://jax.readthedocs.io/en/latest/installation.html for the different versions of JAX.
import tempfile
import jax
from jax.experimental import jax2tf  # To export JAX models to SavedModel
import optax  # To fine-tune the YDF+JAX model
import pandas as pd  # We use Pandas to load small datasets
import tensorflow as tf  # To create a SavedModel
import ydf  # Yggdrasil Decision Forests
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
# Download and load the dataset as Pandas DataFrames
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")
label = "income"
# Print the first 5 training examples
train_ds.head(5)
| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |
First, we train a YDF model on the dataset.
learner = ydf.GradientBoostedTreesLearner(label=label)
model = learner.train(train_ds)
Train model on 22792 examples Model trained in 0:00:02.277830
We convert the YDF model into a JAX function.
jax_model = model.to_jax_function()
jax_model
The object contains three fields:

- predict: A JAX function that computes the model predictions.
- encoder: A callable class that prepares examples for predict. Since JAX does not support string values, categorical string input features must be encoded before calling predict.
- params: An optional dictionary of JAX arrays containing the differentiable parameters of the model. By default, params is None and predict does not take any parameters. We show how to use params in the second part.
We generate predictions for the first 5 examples of the test set.
First, we select some examples and encode them.
# Select the first 5 examples of the Pandas DataFrame and remove the label.
selected_examples = test_ds[:5].drop(model.label(), axis=1)
# Encode the examples into a dictionary of JAX arrays.
jax_selected_examples = jax_model.encoder(selected_examples)
jax_selected_examples
{'age': Array([39, 40, 40, 35, 23], dtype=int32), 'workclass': Array([4, 1, 1, 6, 3], dtype=int32), 'fnlwgt': Array([ 77516, 121772, 193524, 76845, 190709], dtype=int32), 'education': Array([ 3, 5, 13, 11, 7], dtype=int32), 'education_num': Array([13, 11, 16, 5, 12], dtype=int32), 'marital_status': Array([2, 1, 1, 1, 2], dtype=int32), 'occupation': Array([ 4, 3, 1, 10, 12], dtype=int32), 'relationship': Array([2, 1, 1, 1, 2], dtype=int32), 'race': Array([1, 3, 1, 2, 1], dtype=int32), 'sex': Array([1, 1, 1, 1, 1], dtype=int32), 'capital_gain': Array([2174, 0, 0, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 40, 60, 40, 52], dtype=int32), 'native_country': Array([1, 0, 1, 1, 1], dtype=int32)}
Then, we generate the predictions.
jax_predictions = jax_model.predict(jax_selected_examples)
jax_predictions
Array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)
Note that the predictions of the JAX function are equal to the predictions of the YDF model (up to floating-point rounding errors).
model.predict(selected_examples)
array([0.01860435, 0.36130956, 0.83858865, 0.04385567, 0.02917649], dtype=float32)
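To check this programmatically, you can compare the two arrays with a small tolerance. A minimal sketch (not part of the original notebook):
import numpy as np

# Both arrays should match up to float32 rounding noise.
np.testing.assert_allclose(
    np.asarray(jax_predictions),
    model.predict(selected_examples),
    atol=1e-6,
)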
JAX does not define a model serialization format, i.e., a way to save models to disk. Instead, to save a JAX model for serving, it is common to export it to a SavedModel.
# Create a TF module containing the model.
tf_model = tf.Module()
tf_model.predict = tf.function(
    jax2tf.convert(jax_model.predict, with_gradient=False),
    jit_compile=True,
    autograph=False,
)
# Check the predictions of the TF module.
tf_selected_examples = {
    k: tf.constant(v) for k, v in jax_selected_examples.items()
}
tf_predictions = tf_model.predict(tf_selected_examples)
tf_predictions
<tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.01860434, 0.36130956, 0.83858865, 0.04385566, 0.02917648], dtype=float32)>
# Save the TF module to file.
with tempfile.TemporaryDirectory() as tempdir:
    tf.saved_model.save(tf_model, tempdir)
INFO:tensorflow:Assets written to: /tmp/tmp90flesgr/assets
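To sanity-check the export, the SavedModel can be reloaded and queried. A minimal sketch (not part of the original notebook); since the temporary directory is deleted on exit, the round trip happens inside the context:
with tempfile.TemporaryDirectory() as tempdir:
    tf.saved_model.save(tf_model, tempdir)
    # Reload the module and check that it returns the same predictions.
    reloaded_model = tf.saved_model.load(tempdir)
    print(reloaded_model.predict(tf_selected_examples))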
Info: YDF's to_tensorflow_saved_model function can create a SavedModel directly. This approach produces faster models, but it requires installing TensorFlow Decision Forests.
try:
    with tempfile.TemporaryDirectory() as tempdir:
        # Save the YDF model directly in the SavedModel format.
        model.to_tensorflow_saved_model(tempdir, mode="tf")
except Exception as e:
    print("Could not save YDF model to SavedModel with to_tensorflow_saved_model")
[INFO 24-06-14 14:31:56.6553 CEST kernel.cc:1233] Loading model from path /tmp/tmp71lnhoy9/tmp83xu8mjt/ with prefix e57777e0_ [INFO 24-06-14 14:31:56.6795 CEST quick_scorer_extended.cc:911] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference. [INFO 24-06-14 14:31:56.6803 CEST abstract_model.cc:1362] Engine "GradientBoostedTreesQuickScorerExtended" built [INFO 24-06-14 14:31:56.6803 CEST kernel.cc:1061] Use fast generic engine
INFO:tensorflow:Assets written to: /tmp/tmpi0fp69xz/assets
Fine-tuning a YDF model with JAX¶
A distribution shift problem occurs when the examples of interest (the serving examples) follow a different distribution than the training dataset. For example, in hospitals, distribution shifts occur when a model is trained on data acquired with one machine and applied to data from another: although datasets from different machines should be compatible, small differences between them can make a model trained on one dataset perform poorly on the other. For instance, a machine learning model trained to detect tumors in images captured by one device might not work well on images captured by another brand of device. Distribution shifts are also common in dynamic systems that change over time (e.g., user behavior).
In this section, we solve a distribution shift problem with fine-tuning. For this, we use a slightly modified version of the Adult dataset: we assume that only the people with relationship=Wife are of interest to us. However, these people represent only 5% of the dataset, so we only have a small number of training examples.
We first observe that neither training only on the relationship=Wife examples nor training on all available examples produces the best model. Instead, we train a YDF model on all the examples, fine-tune it with JAX on the relationship=Wife examples, and observe that the fine-tuned model performs better. Finally, the fine-tuned JAX model is converted back into a YDF model and analyzed with YDF tools.
First, let's print the distribution of relationship among the test examples. Our goal is to maximize the quality of the model on the 483 relationship == Wife examples.
test_ds["relationship"].value_counts()
relationship Husband 4002 Not-in-family 2505 Own-child 1521 Unmarried 948 Wife 483 Other-relative 310 Name: count, dtype: int64
We split the dataset into two groups: group A contains the examples with relationship != Wife, and group B contains the examples with relationship == Wife.
def is_group_B(ds):
    return ds["relationship"] == "Wife"

train_ds_group_A = train_ds[~is_group_B(train_ds)]
test_ds_group_A = test_ds[~is_group_B(test_ds)]
train_ds_group_B = train_ds[is_group_B(train_ds)]
test_ds_group_B = test_ds[is_group_B(test_ds)]
print("Number of examples per group")
print("\tTrain Group A:", len(train_ds_group_A))
print("\tTest Group A:", len(test_ds_group_A))
print("\tTrain Group B:", len(train_ds_group_B))
print("\tTest Group B:", len(test_ds_group_B))
Number of examples per group Train Group A: 21707 Test Group A: 9286 Train Group B: 1085 Test Group B: 483
Note that group A contains many more examples than group B, but we only care about the test examples in group B.
Let's train and evaluate three models on different combinations of groups A and B. They will be our baselines.
# Train a model on group A
model_group_A = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds_group_A, verbose=0
)

# Train a model on group B
model_group_B = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds_group_B, verbose=0
)

# Train a model on groups A + B
model_group_AB = ydf.GradientBoostedTreesLearner(label=label).train(
    train_ds, verbose=0
)
# Evaluate the three models on the group B test examples
accuracy_test_B_model_A = model_group_A.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_B = model_group_B.evaluate(test_ds_group_B).accuracy
accuracy_test_B_model_AB = model_group_AB.evaluate(test_ds_group_B).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
Accuracy on B, model trained on A: 0.7204968944099379 Accuracy on B, model trained on B: 0.7329192546583851 Accuracy on B, model trained on A+B: 0.7556935817805382
The model trained on both groups A and B performs best on group B. Can we do better?
Let's convert the model trained on A+B into a JAX function.
jax_model_group_AB = model_group_AB.to_jax_function(
    apply_activation=False,
    leaves_as_params=True,
)
jax_model_group_AB.params
{'leaf_values': Array([-0.1233467 , -0.0927111 , 0.2927755 , ..., 0.05464426, 0.12556875, -0.11374608], dtype=float32), 'initial_predictions': Array([-1.1630996], dtype=float32)}
Notes:

- apply_activation=False removes the activation function from the model, so predict outputs logits. This makes it possible to compute the fine-tuning loss on the logits instead of the probabilities, which is more numerically stable.
- leaves_as_params=True specifies that the leaf values are exported as model parameters in params. This is necessary to fine-tune the model.
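Since predict now returns logits, probabilities can be recovered by applying the sigmoid manually. A minimal sketch (not part of the original notebook):
# Encode a few examples and compute logits with the exported parameters.
encoded = jax_model_group_AB.encoder(selected_examples)
logits = jax_model_group_AB.predict(encoded, jax_model_group_AB.params)
# Apply the sigmoid to turn the logits into probabilities.
probabilities = jax.nn.sigmoid(logits)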
To fine-tune the model, we need to generate batches of examples. The following code block generates such batches.
def get_num_examples(ds):
    return len(next(iter(ds.values())))

def prepare_dataset(ds, jax_model, batch=100):
    ds = ds.copy()
    # Make the label boolean
    ds[label] = ds[label] == ">50K"
    # Encode the input features
    encoded_ds = jax_model.encoder(ds)
    # Generate batches of examples
    n = get_num_examples(encoded_ds)
    i = 0
    while i < n:
        begin_idx = i
        end_idx = min(i + batch, n)
        yield {k: v[begin_idx:end_idx] for k, v in encoded_ds.items()}
        i += batch

# Example usage of "prepare_dataset".
for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB, batch=4):
    print(examples)
    break  # We only print the first batch
{'age': Array([44, 67, 26, 30], dtype=int32), 'workclass': Array([1, 5, 0, 1], dtype=int32), 'fnlwgt': Array([228057, 171564, 167835, 118551], dtype=int32), 'education': Array([9, 1, 3, 3], dtype=int32), 'education_num': Array([ 4, 9, 13, 13], dtype=int32), 'marital_status': Array([1, 1, 1, 1], dtype=int32), 'occupation': Array([ 7, 1, 0, 11], dtype=int32), 'relationship': Array([5, 5, 5, 5], dtype=int32), 'race': Array([1, 1, 1, 1], dtype=int32), 'sex': Array([2, 2, 2, 2], dtype=int32), 'capital_gain': Array([ 0, 20051, 0, 0], dtype=int32), 'capital_loss': Array([0, 0, 0, 0], dtype=int32), 'hours_per_week': Array([40, 30, 20, 16], dtype=int32), 'native_country': Array([12, 10, 1, 1], dtype=int32), 'income': Array([False, True, False, True], dtype=bool)}
Let's define some utilities to compute and print the loss and accuracy of the model.
@jax.jit
def compute_accuracy(params, examples, logit=True):
    examples = examples.copy()
    labels = examples.pop(model.label())
    predictions = jax_model_group_AB.predict(examples, params)
    return ((predictions >= 0.0) == labels).mean()

@jax.jit
def compute_loss(params, examples):
    examples = examples.copy()
    labels = examples.pop(model.label())
    logits = jax_model_group_AB.predict(examples, params)
    return optax.sigmoid_binary_cross_entropy(logits, labels).mean()

def compute_metric(metric_fn, ds):
    sum_metrics = 0
    num_examples = 0
    for examples in prepare_dataset(ds, jax_model_group_AB):
        n = get_num_examples(examples)
        sum_metrics += n * metric_fn(jax_model_group_AB.params, examples)
        num_examples += n
    return float(sum_metrics / num_examples)

def print_logs(stage):
    train_accuracy = compute_metric(compute_accuracy, train_ds_group_B)
    train_loss = compute_metric(compute_loss, train_ds_group_B)
    test_accuracy = compute_metric(compute_accuracy, test_ds_group_B)
    test_loss = compute_metric(compute_loss, test_ds_group_B)
    print(
        f"stage:{stage:10} "
        f"test-accuracy:{test_accuracy:.5f} test-loss:{test_loss:.5f} "
        f"train-accuracy:{train_accuracy:.5f} train-loss:{train_loss:.5f}"
    )
# Metrics before fine-tuning the model.
print_logs("initial")
stage:initial test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099
Here is the training loop.
optimizer = optax.adam(0.001)

@jax.jit
def train_step(opt_state, mdl_state, examples):
    loss, grads = jax.value_and_grad(compute_loss)(mdl_state, examples)
    updates, opt_state = optimizer.update(grads, opt_state)
    mdl_state = optax.apply_updates(mdl_state, updates)
    return opt_state, mdl_state, loss

opt_state = optimizer.init(jax_model_group_AB.params)

for epoch_idx in range(10):
    print_logs(f"epoch_{epoch_idx}")
    for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB):
        opt_state, jax_model_group_AB.params, _ = train_step(
            opt_state, jax_model_group_AB.params, examples
        )

print_logs("final")
stage:epoch_0 test-accuracy:0.75569 test-loss:0.47798 train-accuracy:0.83963 train-loss:0.37099 stage:epoch_1 test-accuracy:0.75155 test-loss:0.48035 train-accuracy:0.84424 train-loss:0.36520 stage:epoch_2 test-accuracy:0.75776 test-loss:0.47823 train-accuracy:0.84240 train-loss:0.35878 stage:epoch_3 test-accuracy:0.75983 test-loss:0.48016 train-accuracy:0.84608 train-loss:0.35352 stage:epoch_4 test-accuracy:0.75776 test-loss:0.48063 train-accuracy:0.84793 train-loss:0.34862 stage:epoch_5 test-accuracy:0.75569 test-loss:0.48173 train-accuracy:0.85069 train-loss:0.34419 stage:epoch_6 test-accuracy:0.75776 test-loss:0.48283 train-accuracy:0.85346 train-loss:0.34008 stage:epoch_7 test-accuracy:0.75776 test-loss:0.48381 train-accuracy:0.85806 train-loss:0.33622 stage:epoch_8 test-accuracy:0.75983 test-loss:0.48495 train-accuracy:0.86175 train-loss:0.33260 stage:epoch_9 test-accuracy:0.75983 test-loss:0.48595 train-accuracy:0.86267 train-loss:0.32917 stage:final test-accuracy:0.75983 test-loss:0.48703 train-accuracy:0.86359 train-loss:0.32592
Notice that both the train and test accuracy improve during fine-tuning, while the test loss slowly drifts up, a sign that fine-tuning for many more epochs would overfit.
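To guard against this, you may prefer to keep the parameters with the best test loss rather than the final ones. A hypothetical variant of the loop above (a sketch, not part of the original notebook):
# Snapshot the parameters with the lowest test loss seen during fine-tuning.
best_loss = float("inf")
best_params = jax_model_group_AB.params
for epoch_idx in range(10):
    for examples in prepare_dataset(train_ds_group_B, jax_model_group_AB):
        opt_state, jax_model_group_AB.params, _ = train_step(
            opt_state, jax_model_group_AB.params, examples
        )
    test_loss = compute_metric(compute_loss, test_ds_group_B)
    if test_loss < best_loss:
        best_loss, best_params = test_loss, jax_model_group_AB.params
jax_model_group_AB.params = best_params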
We can now update the YDF model with the fine-tuned parameters.
model_group_AB.update_with_jax_params(jax_model_group_AB.params)
model_group_AB is now the fine-tuned model. Let's evaluate it and compare it to the other models:
accuracy_test_B_model_AB_finetuned_B = model_group_AB.evaluate(
    test_ds_group_B
).accuracy
print("Accuracy on B, model trained on A:", accuracy_test_B_model_A)
print("Accuracy on B, model trained on B:", accuracy_test_B_model_B)
print("Accuracy on B, model trained on A+B:", accuracy_test_B_model_AB)
print("==================================")
print(
    "Accuracy on B, model trained on A+B, finetuned on B:",
    accuracy_test_B_model_AB_finetuned_B,
)
Accuracy on B, model trained on A: 0.7204968944099379 Accuracy on B, model trained on B: 0.7329192546583851 Accuracy on B, model trained on A+B: 0.7556935817805382 ================================== Accuracy on B, model trained on A+B, finetuned on B: 0.7598343685300207
Notice that the new model, trained on A+B and fine-tuned on B, shows the best test accuracy on group B.
model_group_AB is a YDF model like any other. For example, you can save it and analyze it.
# Save the model
with tempfile.TemporaryDirectory() as tempdir:
    model_group_AB.save(tempdir)
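Because the temporary directory is deleted when the context exits, a round trip has to happen inside it. A minimal sketch (not part of the original notebook), assuming the standard ydf.load_model function:
with tempfile.TemporaryDirectory() as tempdir:
    model_group_AB.save(tempdir)
    # Reload the fine-tuned model and re-run the evaluation.
    reloaded_model = ydf.load_model(tempdir)
    print(reloaded_model.evaluate(test_ds_group_B).accuracy)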
# Analyze the model
model_group_AB.analyze(test_ds_group_B)
Variable importances measure the importance of an input feature for a model.
1. "capital_gain" 0.049689 ################ 2. "occupation" 0.045549 ############## 3. "education" 0.026915 ######## 4. "education_num" 0.026915 ######## 5. "age" 0.018634 ###### 6. "capital_loss" 0.018634 ###### 7. "workclass" 0.014493 ##### 8. "fnlwgt" 0.002070 # 9. "native_country" 0.002070 # 10. "relationship" 0.000000 11. "race" 0.000000 12. "sex" 0.000000 13. "hours_per_week" 0.000000 14. "marital_status" -0.002070
1. "capital_gain" 0.164288 ################ 2. "capital_loss" 0.048263 ##### 3. "occupation" 0.033196 ### 4. "education" 0.023903 ## 5. "education_num" 0.015137 ## 6. "age" 0.013872 # 7. "workclass" 0.006274 # 8. "race" 0.002477 9. "sex" 0.001453 10. "fnlwgt" 0.000984 11. "marital_status" 0.000722 12. "relationship" 0.000000 13. "native_country" -0.000019 14. "hours_per_week" -0.007143
1. "capital_gain" 0.083385 ################ 2. "occupation" 0.040765 ######## 3. "capital_loss" 0.030647 ###### 4. "education" 0.026051 ##### 5. "age" 0.024419 ##### 6. "education_num" 0.016887 #### 7. "workclass" 0.010427 ## 8. "race" 0.003161 # 9. "marital_status" 0.000790 # 10. "sex" 0.000704 # 11. "relationship" 0.000000 # 12. "native_country" -0.000361 # 13. "fnlwgt" -0.001022 14. "hours_per_week" -0.006107
1. "capital_gain" 0.162868 ################ 2. "capital_loss" 0.048043 ##### 3. "occupation" 0.033135 ### 4. "education" 0.023881 ## 5. "education_num" 0.015116 ## 6. "age" 0.013875 # 7. "workclass" 0.006275 # 8. "race" 0.002472 9. "sex" 0.001448 10. "fnlwgt" 0.000990 11. "marital_status" 0.000721 12. "relationship" 0.000000 13. "native_country" -0.000014 14. "hours_per_week" -0.007106
1. "age" 0.226642 ################ 2. "occupation" 0.219727 ############# 3. "capital_gain" 0.214876 ############ 4. "education" 0.213746 ########### 5. "marital_status" 0.212739 ########### 6. "relationship" 0.206040 ######### 7. "fnlwgt" 0.203843 ######## 8. "hours_per_week" 0.203735 ######## 9. "capital_loss" 0.196549 ###### 10. "native_country" 0.190548 #### 11. "workclass" 0.187795 ### 12. "education_num" 0.184215 ## 13. "race" 0.180495 14. "sex" 0.177647
1. "age" 26.000000 ################ 2. "capital_gain" 26.000000 ################ 3. "marital_status" 20.000000 ############ 4. "relationship" 17.000000 ########## 5. "capital_loss" 14.000000 ######## 6. "hours_per_week" 14.000000 ######## 7. "education" 12.000000 ####### 8. "fnlwgt" 10.000000 ##### 9. "race" 9.000000 ##### 10. "education_num" 7.000000 ### 11. "sex" 4.000000 # 12. "occupation" 2.000000 13. "workclass" 1.000000 14. "native_country" 1.000000
1. "occupation" 724.000000 ################ 2. "fnlwgt" 513.000000 ########### 3. "age" 483.000000 ########## 4. "education" 464.000000 ########## 5. "hours_per_week" 339.000000 ####### 6. "capital_gain" 326.000000 ###### 7. "native_country" 306.000000 ###### 8. "capital_loss" 297.000000 ###### 9. "relationship" 262.000000 ##### 10. "workclass" 244.000000 ##### 11. "marital_status" 210.000000 #### 12. "education_num" 82.000000 # 13. "sex" 42.000000 14. "race" 21.000000
1. "relationship" 3014.690076 ################ 2. "capital_gain" 2065.521668 ########## 3. "education" 1144.490954 ###### 4. "marital_status" 1111.389695 ##### 5. "occupation" 1094.619502 ##### 6. "education_num" 796.666823 #### 7. "capital_loss" 584.055066 ### 8. "age" 582.288569 ### 9. "hours_per_week" 366.856509 # 10. "native_country" 263.872689 # 11. "fnlwgt" 216.537764 # 12. "workclass" 196.085503 # 13. "sex" 47.217730 14. "race" 5.428727