代码示例 / 结构化数据 / 使用 TensorFlow Decision Forests 进行分类

使用 TensorFlow Decision Forests 进行分类

作者: Khalid Salama
创建日期: 2022/01/25
最后修改: 2022/01/25
描述: 使用 TensorFlow Decision Forests 进行结构化数据分类。

在 Colab 中查看 GitHub 源代码


介绍

TensorFlow Decision Forests 是一组先进的决策森林模型算法,兼容 Keras API。这些模型包括 随机森林梯度提升树CART,可用于回归、分类和排序任务。有关 TensorFlow Decision Forests 的初学者指南,请参阅此 教程

此示例使用梯度提升树模型对结构化数据进行二分类,并涵盖以下场景:

  1. 通过指定输入特征的使用情况来构建决策森林模型。
  2. 实现一个自定义的 二元目标编码器 作为 Keras 预处理层,根据目标值共现对类别特征进行编码,然后使用编码后的特征构建决策森林模型。
  3. 将类别特征编码为 嵌入,在简单的神经网络模型中训练这些嵌入,然后使用训练好的嵌入作为输入构建决策森林模型。

此示例使用 TensorFlow 2.7 或更高版本,以及 TensorFlow Decision Forests,你可以使用以下命令安装:

pip install -U tensorflow_decision_forests

设置

import math
import urllib
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_decision_forests as tfdf

准备数据

此示例使用 美国人口普查收入数据集,该数据集由 加州大学尔湾分校机器学习库 提供。任务是进行二分类,以确定一个人是否年收入超过 50K。

数据集包含约 30 万个实例和 41 个输入特征:7 个数值特征和 34 个类别特征。

首先,我们将数据从 UCI 机器学习库加载到 Pandas DataFrame 中。

BASE_PATH = "https://kdd.ics.uci.edu/databases/census-income/census-income"
CSV_HEADER = [
    l.decode("utf-8").split(":")[0].replace(" ", "_")
    for l in urllib.request.urlopen(f"{BASE_PATH}.names")
    if not l.startswith(b"|")
][2:]
CSV_HEADER.append("income_level")

train_data = pd.read_csv(f"{BASE_PATH}.data.gz", header=None, names=CSV_HEADER,)
test_data = pd.read_csv(f"{BASE_PATH}.test.gz", header=None, names=CSV_HEADER,)

定义数据集的元数据

在这里,我们定义数据集的元数据,这将有助于根据特征类型对输入特征进行编码。

# 目标列名称。
TARGET_COLUMN_NAME = "income_level"
# 目标列的标签。
TARGET_LABELS = [" - 50000.", " 50000+."]
# 权重列名称。
WEIGHT_COLUMN_NAME = "instance_weight"
# 数值特征名称。
NUMERIC_FEATURE_NAMES = [
    "age",
    "wage_per_hour",
    "capital_gains",
    "capital_losses",
    "dividends_from_stocks",
    "num_persons_worked_for_employer",
    "weeks_worked_in_year",
]
# 分类特征及其词汇表。
CATEGORICAL_FEATURE_NAMES = [
    "class_of_worker",
    "detailed_industry_recode",
    "detailed_occupation_recode",
    "education",
    "enroll_in_edu_inst_last_wk",
    "marital_stat",
    "major_industry_code",
    "major_occupation_code",
    "race",
    "hispanic_origin",
    "sex",
    "member_of_a_labor_union",
    "reason_for_unemployment",
    "full_or_part_time_employment_stat",
    "tax_filer_stat",
    "region_of_previous_residence",
    "state_of_previous_residence",
    "detailed_household_and_family_stat",
    "detailed_household_summary_in_household",
    "migration_code-change_in_msa",
    "migration_code-change_in_reg",
    "migration_code-move_within_reg",
    "live_in_this_house_1_year_ago",
    "migration_prev_res_in_sunbelt",
    "family_members_under_18",
    "country_of_birth_father",
    "country_of_birth_mother",
    "country_of_birth_self",
    "citizenship",
    "own_business_or_self_employed",
    "fill_inc_questionnaire_for_veteran's_admin",
    "veterans_benefits",
    "year",
]

现在我们执行基本数据准备。

def prepare_dataframe(dataframe):
    # 将目标标签从字符串转换为整数。
    dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
        TARGET_LABELS.index
    )
    # 将类别特征转换为字符串。
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        dataframe[feature_name] = dataframe[feature_name].astype(str)


prepare_dataframe(train_data)
prepare_dataframe(test_data)

现在让我们显示训练和测试数据框的形状,并显示一些实例。

print(f"Train data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")
print(train_data.head().T)
Train data shape: (199523, 42)
Test data shape: (99762, 42)
                                                                                    0  \
age                                                                                73   
class_of_worker                                                       不在范围内   
detailed_industry_recode                                                            0   
detailed_occupation_recode                                                          0   
education                                                        高中毕业   
wage_per_hour                                                                       0   
enroll_in_edu_inst_last_wk                                            不在范围内   
marital_stat                                                                  寡居   
major_industry_code                                       不在范围内或儿童   
major_occupation_code                                                 不在范围内   
race                                                                            白人   
hispanic_origin                                                             所有其他   
sex                                                                            女性   
member_of_a_labor_union                                               不在范围内   
reason_for_unemployment                                               不在范围内   
full_or_part_time_employment_stat                                  不在劳动力中   
capital_gains                                                                       0   
capital_losses                                                                      0   
dividends_from_stocks                                                               0   
tax_filer_stat                                                               非报税人   
region_of_previous_residence                                          不在范围内   
state_of_previous_residence                                           不在范围内   
detailed_household_and_family_stat           其他亲属 18岁以上曾婚不在子家庭中   
detailed_household_summary_in_household                 其他户主的亲属   
instance_weight                                                               1700.09   
migration_code-change_in_msa                                                        ?   
migration_code-change_in_reg                                                        ?   
migration_code-move_within_reg                                                      ?   
live_in_this_house_1_year_ago                        不在范围内 不到1岁   
migration_prev_res_in_sunbelt                                                       ?   
num_persons_worked_for_employer                                                     0   
family_members_under_18                                               不在范围内   
country_of_birth_father                                                 美国   
country_of_birth_mother                                                 美国   
country_of_birth_self                                                   美国   
citizenship                                         本土- 生于美国   
own_business_or_self_employed                                                       0   
fill_inc_questionnaire_for_veteran's_admin                            不在范围内   
veterans_benefits                                                                   2   
weeks_worked_in_year                                                                0   
year                                                                               95   
income_level                                                                        0   
                                                                           1  \

age 58
class_of_worker 自雇-未注册
detailed_industry_recode 4
detailed_occupation_recode 34
education 某些大学课程但没有学位
wage_per_hour 0
enroll_in_edu_inst_last_wk 不属于该范围
marital_stat 离婚
major_industry_code 建筑业
major_occupation_code 精密制造工艺与维修
race 白人
hispanic_origin 所有其他
sex 男性
member_of_a_labor_union 不属于该范围
reason_for_unemployment 不属于该范围
full_or_part_time_employment_stat 孩子或武装部队
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat 家庭户主
region_of_previous_residence 南部
state_of_previous_residence 阿肯色州
detailed_household_and_family_stat 户主
detailed_household_summary_in_household 户主
instance_weight 1053.55
migration_code-change_in_msa MSA到MSA
migration_code-change_in_reg 同一县
migration_code-move_within_reg 同一县
live_in_this_house_1_year_ago 否
migration_prev_res_in_sunbelt 是
num_persons_worked_for_employer 1
family_members_under_18 不属于该范围
country_of_birth_father 美国
country_of_birth_mother 美国
country_of_birth_self 美国
citizenship 以美国出生的公民
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin 不属于该范围
veterans_benefits 2
weeks_worked_in_year 52
year 94
income_level 0

                                                                                   2  \
年龄                                                                               18   
工作类别                                                      不在统计范围内   
详细行业编码                                                               0   
详细职业编码                                                               0   
教育程度                                                                 10年级   
每小时工资                                                                      0   
上周是否在教育机构就读                                               高中   
婚姻状况                                                           未婚   
主要行业代码                                      不在统计范围内或儿童   
主要职业代码                                                不在统计范围内   
种族                                                       亚裔或太平洋岛民   
西班牙裔身份                                                            其他所有   
性别                                                                           女性   
是否为工会成员                                              不在统计范围内   
失业原因                                              不在统计范围内   
全职或兼职就业状态                                 不在劳动力市场中   
资本收益                                                                      0   
资本损失                                                                     0   
股票分红                                                              0   
纳税人状态                                                              非纳税人   
前居住地区                                         不在统计范围内   
前居住州                                          不在统计范围内   
详细家庭和家庭状态           18岁及以上儿童未婚未形成小家庭   
详细家庭摘要                                                18岁或以上的儿童   
实例权重                                                               991.95   
迁移代码-变化的MSA                                                       ?   
迁移代码-变化的地区                                                       ?   
迁移代码-在地区内移动                                                     ?   
一年前是否居住在此房屋                       不在统计范围内,未满1岁   
迁移前居住在阳光带                                                      ?   
为雇主工作的人数                                                    0   
18岁以下家庭成员                                              不在统计范围内   
父亲出生国家                                                      越南   
母亲出生国家                                                      越南   
本人出生国家                                                        越南   
国籍                                      外国出生- 非美国公民    
自营或自主经营                                                      0   
为退伍军人事务部填写收入问卷                           不在统计范围内   
退伍军人福利                                                                  2   
一年工作周数                                                               0   
年份                                                                              95   
收入水平                                                                       0   
                                                                                 3  \
年龄                                                                              9   
工作类别                                                      不在定义范围内   
详细行业代码                                                           0   
详细职业代码                                                             0   
教育程度                                                                 儿童   
每小时工资                                                                    0   
上周教育机构注册情况                                       不在定义范围内   
婚姻状况                                                         从未结婚   
主要行业代码                                    不在定义范围内或儿童   
主要职业代码                                              不在定义范围内   
种族                                                                         白人   
西班牙裔来源                                                          其他所有   
性别                                                                         女性   
是否为工会成员                                            不在定义范围内   
失业原因                                            不在定义范围内   
全职或兼职工作状态                         儿童或武装力量   
资本收益                                                                    0   
资本损失                                                                   0   
股票分红                                                            0   
报税状态                                                            非报税者   
之前居住地区                                       不在定义范围内   
之前居住州                                        不在定义范围内   
详细家庭和家庭成员状态           18岁以下儿童从未结婚且不在子家庭中   
详细家庭摘要在户内               18岁以下儿童从未结婚   
个案权重                                                            1758.14   
迁移代码-大都市区变化                                              非迁移者   
迁移代码-地区变化                                              非迁移者   
迁移代码-区内移动                                            非迁移者   
一年之前住在这所房子里                                                  是   
迁移前在阳光带的居住情况                                      不在定义范围内   
为雇主工作的人数                                                  0   
18岁以下家庭成员                                       双亲在场   
父亲出生国                                              美国   
母亲出生国                                              美国   
本人出生国                                                美国   
国籍                                      本土- 在美国出生   
拥有自己的生意或自雇                                                    0   
为退伍军人事务部填写收入问卷                         不在定义范围内   
退伍军人福利                                                                0   
年度工作周数                                                             0   
年份                                                                            94   
收入水平                                                                     0   
                                                                                 4  
年龄                                                                             10  
工作类别                                                    不在宇宙中  
详细行业编码                                                         0  
详细职业编码                                                       0  
教育                                                                 儿童  
每小时工资                                                                    0  
上周参加教育机构                                                     不在宇宙中  
婚姻状况                                                         从未结婚  
主要行业代码                                    不在宇宙中或儿童  
主要职业代码                                              不在宇宙中  
种族                                                                         白人  
西班牙裔来源                                                          所有其他  
性别                                                                         女性  
工会成员                                            不在宇宙中  
失业原因                                            不在宇宙中  
全职或兼职就业状态                         儿童或军队  
资本收益                                                                    0  
资本损失                                                                   0  
股票的股息                                                            0  
报税状态                                                            非申报者  
以前居住地区                                       不在宇宙中  
以前居住州                                        不在宇宙中  
详细家庭和家庭统计           18岁以下儿童从未结过婚不在子家庭中  
详细家庭摘要               18岁以下儿童从未结婚  
实例权重                                                            1069.16  
迁移代码-城市区变化                                              不迁移  
迁移代码-地区变化                                              不迁移  
迁移代码-区域内移动                                            不迁移  
一年前住在这个房子里                                                  是  
迁移前在阳光带的居住地                                      不在宇宙中  
为雇主工作的人数                                                  0  
18岁以下家庭成员                                       双亲在场  
父亲出生国家                                              美国  
母亲出生国家                                              美国  
本人出生国家                                                美国  
国籍                                      本土- 在美国出生  
拥有业务或自雇                                                    0  
为退伍军人事务管理局填写收入问卷                         不在宇宙中  
退伍军人福利                                                                0  
一年内工作周数                                                             0  
年份                                                                            94  
收入水平                                                                     0  

配置超参数

您可以在 文档 中找到梯度提升树模型的所有参数

# 决策树的最大数量。如果启用了提前停止,则有效训练的树的数量可能更小。
NUM_TREES = 250
# 节点中的最小示例数量。
MIN_EXAMPLES = 6
# 树的最大深度。max_depth=1表示所有树都是根。
MAX_DEPTH = 5
# 用于训练个别树的随机抽样方法的数据集比例(不替换抽样)。
SUBSAMPLE = 0.65
# 控制用于训练个别树的数据集的抽样。
SAMPLING_METHOD = "RANDOM"
# 用于监控训练的训练数据集比例。如果启用了提前停止,则需要大于0。
VALIDATION_RATIO = 0.1

实现训练和评估程序

run_experiment() 方法负责加载训练和测试数据集, 训练给定模型并评估训练后的模型。

请注意,在训练决策森林模型时,仅需一个周期即可 读取完整数据集。任何额外步骤都将导致不必要的训练速度减慢。 因此,在run_experiment()方法中使用默认值num_epochs=1

def run_experiment(model, train_data, test_data, num_epochs=1, batch_size=None):

    train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        train_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )
    test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        test_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )

    model.fit(train_dataset, epochs=num_epochs, batch_size=batch_size)
    _, accuracy = model.evaluate(test_dataset, verbose=0)
    print(f"测试准确率: {round(accuracy * 100, 2)}%")

实验 1: 使用原始特征的决策森林

指定模型输入特征的使用

您可以为每个特征附加语义,以控制它被模型如何使用。 如果未指定,则语义将从表示类型中推断得出。 建议明确指定特征使用 以避免错误推断的语义。 例如,分类值标识符(整数)将被推断为数值, 而实际上它是语义上的分类。

对于数值特征,您可以设置discretized参数为数值特征应被离散化的桶数量。 这使得训练更快,但可能导致模型效果更差。

def specify_feature_usages():
    feature_usages = []

    for feature_name in NUMERIC_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
        )
        feature_usages.append(feature_usage)

    for feature_name in CATEGORICAL_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.CATEGORICAL
        )
        feature_usages.append(feature_usage)

    return feature_usages

创建一个梯度提升树模型

在编译一个决策森林模型时,您只能提供额外的评估指标。 损失在模型构建中指定, 与决策森林模型无关的优化器。

def create_gbt_model():
    # 查看所有模型参数 https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        features=specify_feature_usages(),
        exclude_non_specified_features=True,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model

训练和评估模型

gbt_model = create_gbt_model()
run_experiment(gbt_model, train_data, test_data)
开始读取数据集
200/200 [==============================] - ETA: 0s
数据集读取时间为 0:00:08.829036
训练模型
模型训练时间为 0:00:48.639771
编译模型
200/200 [==============================] - 58s 268ms/step
测试准确率: 95.79%

检查模型

model.summary()方法将显示有关 您的决策树模型、模型类型、任务、输入特征和特征重要性的几种信息。

print(gbt_model.summary())
模型: "gradient_boosted_trees_model"
_________________________________________________________________
 层 (类型)                输出形状              参数数量   
=================================================================
=================================================================
总参数: 1
可训练参数: 0
不可训练参数: 1
_________________________________________________________________
类型: "GRADIENT_BOOSTED_TREES"
任务: 分类
标签: "__LABEL"
输入特征 (40):
年龄
资本增益
资本损失
国籍
工作类别
父亲出生国家
母亲出生国家
本人出生国家
详细家庭和家庭状况统计
详细家庭摘要在家庭中
详细行业编码
详细职业编码
股票分红
教育
上周注册教育机构
18岁以下家庭成员
填写退伍军人管理局的收入问卷
全职或兼职就业状态
西班牙裔来源
一年前住在此房屋
主要行业代码
主要职业代码
婚姻状况
工会成员
迁移代码 - 大都市区变化
迁移代码 - 地区变化
迁移代码 - 区域内移动
迁移之前居住在阳光地带
为雇主工作的人员数量
拥有企业或自雇
种族
失业原因
之前居住地区
性别
之前居住州
纳税申报状态
退伍军人福利
每小时工资
年度工作周数
年份
使用权重进行训练
变量重要性:平均最小深度:
    1.                 "enroll_in_edu_inst_last_wk"  3.942647 ################
    2.                    "family_members_under_18"  3.942647 ################
    3.              "live_in_this_house_1_year_ago"  3.942647 ################
    4.               "migration_code-change_in_msa"  3.942647 ################
    5.             "migration_code-move_within_reg"  3.942647 ################
    6.                                       "year"  3.942647 ################
    7.                                    "__LABEL"  3.942647 ################
    8.                                  "__WEIGHTS"  3.942647 ################
    9.                                "citizenship"  3.942137 ###############
   10.    "detailed_household_summary_in_household"  3.942137 ###############
   11.               "region_of_previous_residence"  3.942137 ###############
   12.                          "veterans_benefits"  3.942137 ###############
   13.              "migration_prev_res_in_sunbelt"  3.940135 ###############
   14.               "migration_code-change_in_reg"  3.939926 ###############
   15.                      "major_occupation_code"  3.937681 ###############
   16.                        "major_industry_code"  3.933687 ###############
   17.                    "reason_for_unemployment"  3.926320 ###############
   18.                            "hispanic_origin"  3.900776 ###############
   19.                    "member_of_a_labor_union"  3.894843 ###############
   20.                                       "race"  3.878617 ###############
   21.            "num_persons_worked_for_employer"  3.818566 ##############
   22.                               "marital_stat"  3.795667 ##############
   23.          "full_or_part_time_employment_stat"  3.795431 ##############
   24.                    "country_of_birth_mother"  3.787967 ##############
   25.                             "tax_filer_stat"  3.784505 ##############
   26. "fill_inc_questionnaire_for_veteran's_admin"  3.783607 ##############
   27.              "own_business_or_self_employed"  3.776398 ##############
   28.                    "country_of_birth_father"  3.715252 #############
   29.                                        "sex"  3.708745 #############
   30.                            "class_of_worker"  3.688424 #############
   31.                       "weeks_worked_in_year"  3.665290 #############
   32.                "state_of_previous_residence"  3.657234 #############
   33.                      "country_of_birth_self"  3.654377 #############
   34.                                        "age"  3.634295 ############
   35.                              "wage_per_hour"  3.617817 ############
   36.         "detailed_household_and_family_stat"  3.594743 ############
   37.                             "capital_losses"  3.439298 ##########
   38.                      "dividends_from_stocks"  3.423652 ##########
   39.                              "capital_gains"  3.222753 ########
   40.                                  "education"  3.158698 ########
   41.                   "detailed_industry_recode"  2.981471 ######
   42.                 "detailed_occupation_recode"  2.364817 
变量重要性:NUM_AS_ROOT:
    1.                                  "教育" 33.000000 ################
    2.                              "资本收益" 29.000000 ##############
    3.                             "资本损失" 24.000000 ###########
    4.         "详细家庭和家庭统计" 14.000000 ######
    5.                      "股票分红" 14.000000 ######
    6.                              "每小时工资" 12.000000 #####
    7.                      "出生国家" 11.000000 #####
    8.                 "详细职业重编码" 11.000000 #####
    9.                       "一年工作周数" 11.000000 #####
   10.                                        "年龄" 10.000000 ####
   11.                "前居住州" 10.000000 ####
   12. "为退伍军人事务部填写收入问卷"  9.000000 ####
   13.                            "工人类别"  8.000000 ###
   14.          "全职或兼职就业状态"  8.000000 ###
   15.                               "婚姻状况"  8.000000 ###
   16.              "自营职业或自雇"  8.000000 ###
   17.                                        "性别"  6.000000 ##
   18.                             "报税状态"  5.000000 ##
   19.                    "父亲出生国家"  4.000000 #
   20.                                       "种族"  3.000000 #
   21.                   "详细行业重编码"  2.000000 
   22.                            "西班牙裔来源"  2.000000 
   23.                    "母亲出生国家"  1.000000 
   24.            "为雇主工作的人数"  1.000000 
   25.                    "失业原因"  1.000000 
变量重要性: NUM_NODES:
    1.                 "detailed_occupation_recode" 785.000000 ################
    2.                   "detailed_industry_recode" 668.000000 #############
    3.                              "capital_gains" 275.000000 #####
    4.                      "dividends_from_stocks" 220.000000 ####
    5.                             "capital_losses" 197.000000 ####
    6.                                  "education" 178.000000 ###
    7.                    "country_of_birth_mother" 128.000000 ##
    8.                    "country_of_birth_father" 116.000000 ##
    9.                                        "age" 114.000000 ##
   10.                              "wage_per_hour" 98.000000 #
   11.                "state_of_previous_residence" 95.000000 #
   12.         "detailed_household_and_family_stat" 78.000000 #
   13.                            "class_of_worker" 67.000000 #
   14.                      "country_of_birth_self" 65.000000 #
   15.                                        "sex" 65.000000 #
   16.                       "weeks_worked_in_year" 60.000000 #
   17.                             "tax_filer_stat" 57.000000 #
   18.            "num_persons_worked_for_employer" 54.000000 #
   19.              "own_business_or_self_employed" 30.000000 
   20.                               "marital_stat" 26.000000 
   21.                    "member_of_a_labor_union" 16.000000 
   22. "fill_inc_questionnaire_for_veteran's_admin" 15.000000 
   23.          "full_or_part_time_employment_stat" 15.000000 
   24.                        "major_industry_code" 15.000000 
   25.                            "hispanic_origin"  9.000000 
   26.                      "major_occupation_code"  7.000000 
   27.                                       "race"  7.000000 
   28.                                "citizenship"  1.000000 
   29.    "detailed_household_summary_in_household"  1.000000 
   30.               "migration_code-change_in_reg"  1.000000 
   31.              "migration_prev_res_in_sunbelt"  1.000000 
   32.                    "reason_for_unemployment"  1.000000 
   33.               "region_of_previous_residence"  1.000000 
   34.                          "veterans_benefits"  1.000000 
变量重要性: SUM_SCORE:
    1.                 "detailed_occupation_recode" 15392441.075369 ################
    2.                              "capital_gains" 5277826.822514 #####
    3.                                  "education" 4751749.289550 ####
    4.                      "dividends_from_stocks" 3792002.951255 ###
    5.                   "detailed_industry_recode" 2882200.882109 ##
    6.                                        "sex" 2559417.877325 ##
    7.                                        "age" 2042990.944829 ##
    8.                             "capital_losses" 1735728.772551 #
    9.                       "weeks_worked_in_year" 1272820.203971 #
   10.                             "tax_filer_stat" 697890.160846 
   11.            "num_persons_worked_for_employer" 671351.905595 
   12.         "detailed_household_and_family_stat" 444620.829557 
   13.                            "class_of_worker" 362250.565331 
   14.                    "country_of_birth_mother" 296311.574426 
   15.                    "country_of_birth_father" 258198.889206 
   16.                              "wage_per_hour" 239764.219048 
   17.                "state_of_previous_residence" 237687.602572 
   18.                      "country_of_birth_self" 103002.168158 
   19.                               "marital_stat" 102449.735314 
   20.              "own_business_or_self_employed" 82938.893541 
   21. "fill_inc_questionnaire_for_veteran's_admin" 22692.700206 
   22.          "full_or_part_time_employment_stat" 19078.398837 
   23.                        "major_industry_code" 18450.345505 
   24.                    "member_of_a_labor_union" 14905.360879 
   25.                            "hispanic_origin" 12602.867902 
   26.                      "major_occupation_code" 8709.665989 
   27.                                       "race" 6116.282065 
   28.                                "citizenship" 3291.490393 
   29.    "detailed_household_summary_in_household" 2733.439375 
   30.                          "veterans_benefits" 1230.940488 
   31.               "region_of_previous_residence" 1139.240981 
   32.                    "reason_for_unemployment" 219.245124 
   33.               "migration_code-change_in_reg" 55.806436 
   34.              "migration_prev_res_in_sunbelt" 37.780635 
损失: BINOMIAL_LOG_LIKELIHOOD
验证损失值: 0.228983
每次迭代的树的数量: 1
节点格式: NOT_SET
树的总数: 245
节点总数: 7179
树的节点数:
计数:245 平均:29.302 标准差:2.96211
最小:17 最大:31 忽略:0
----------------------------------------------
[ 17, 18)   2   0.82%   0.82%
[ 18, 19)   0   0.00%   0.82%
[ 19, 20)   3   1.22%   2.04%
[ 20, 21)   0   0.00%   2.04%
[ 21, 22)   4   1.63%   3.67%
[ 22, 23)   0   0.00%   3.67%
[ 23, 24)  15   6.12%   9.80% #
[ 24, 25)   0   0.00%   9.80%
[ 25, 26)   5   2.04%  11.84%
[ 26, 27)   0   0.00%  11.84%
[ 27, 28)  21   8.57%  20.41% #
[ 28, 29)   0   0.00%  20.41%
[ 29, 30)  39  15.92%  36.33% ###
[ 30, 31)   0   0.00%  36.33%
[ 31, 31] 156  63.67% 100.00% ##########
叶子节点的深度:
计数:3712 平均:3.95259 标准差:0.249814
最小:2 最大:4 忽略:0
----------------------------------------------
[ 2, 3)   32   0.86%   0.86%
[ 3, 4)  112   3.02%   3.88%
[ 4, 4] 3568  96.12% 100.00% ##########
按叶子节点的训练观察数量:
计数:3712 平均:11849.3 标准差:33719.3
最小:6 最大:179360 忽略:0
----------------------------------------------
[      6,   8973) 3100  83.51%  83.51% ##########
[   8973,  17941)  148   3.99%  87.50%
[  17941,  26909)   79   2.13%  89.63%
[  26909,  35877)   36   0.97%  90.60%
[  35877,  44844)   44   1.19%  91.78%
[  44844,  53812)   17   0.46%  92.24%
[  53812,  62780)   20   0.54%  92.78%
[  62780,  71748)   39   1.05%  93.83%
[  71748,  80715)   24   0.65%  94.48%
[  80715,  89683)   12   0.32%  94.80%
[  89683,  98651)   22   0.59%  95.39%
[  98651, 107619)   21   0.57%  95.96%
[ 107619, 116586)   17   0.46%  96.42%
[ 116586, 125554)   17   0.46%  96.88%
[ 125554, 134522)   13   0.35%  97.23%
[ 134522, 143490)    8   0.22%  97.44%
[ 143490, 152457)    5   0.13%  97.58%
[ 152457, 161425)    6   0.16%  97.74%
[ 161425, 170393)   15   0.40%  98.14%
[ 170393, 179360]   69   1.86% 100.00%
节点中的属性:
785 : detailed_occupation_recode [分类]
668 : detailed_industry_recode [分类]
275 : capital_gains [数值]
220 : dividends_from_stocks [数值]
197 : capital_losses [数值]
178 : education [分类]
128 : country_of_birth_mother [分类]
116 : country_of_birth_father [分类]
114 : age [数值]
98 : wage_per_hour [数值]
95 : state_of_previous_residence [分类]
78 : detailed_household_and_family_stat [分类]
67 : class_of_worker [分类]
65 : sex [分类]
65 : country_of_birth_self [分类]
60 : weeks_worked_in_year [数值]
57 : tax_filer_stat [分类]
54 : num_persons_worked_for_employer [数值]
30 : own_business_or_self_employed [分类]
26 : marital_stat [分类]
16 : member_of_a_labor_union [分类]
15 : major_industry_code [分类]
15 : full_or_part_time_employment_stat [分类]
15 : fill_inc_questionnaire_for_veteran's_admin [分类]
9 : hispanic_origin [分类]
7 : race [分类]
7 : major_occupation_code [分类]
1 : veterans_benefits [分类]
1 : region_of_previous_residence [分类]
1 : reason_for_unemployment [分类]
1 : migration_prev_res_in_sunbelt [分类]
1 : migration_code-change_in_reg [分类]
1 : detailed_household_summary_in_household [分类]
1 : citizenship [分类]
深度 <= 0 的节点中的属性:
33 : education [分类]
29 : capital_gains [数值]
24 : capital_losses [数值]
14 : dividends_from_stocks [数值]
14 : detailed_household_and_family_stat [分类]
12 : wage_per_hour [数值]
11 : weeks_worked_in_year [数值]
11 : detailed_occupation_recode [分类]
11 : country_of_birth_self [分类]
10 : state_of_previous_residence [分类]
10 : age [数值]
9 : fill_inc_questionnaire_for_veteran's_admin [分类]
8 : own_business_or_self_employed [分类]
8 : marital_stat [分类]
8 : full_or_part_time_employment_stat [分类]
8 : class_of_worker [分类]
6 : sex [分类]
5 : tax_filer_stat [分类]
4 : country_of_birth_father [分类]
3 : race [分类]
2 : hispanic_origin [分类]
2 : detailed_industry_recode [分类]
1 : reason_for_unemployment [分类]
1 : num_persons_worked_for_employer [数值]
1 : country_of_birth_mother [分类]
深度 <= 1 的节点中的属性:
140 : detailed_occupation_recode [分类]
82 : capital_gains [数值]
65 : capital_losses [数值]
62 : education [分类]
59 : detailed_industry_recode [分类]
47 : dividends_from_stocks [数值]
31 : wage_per_hour [数值]
26 : detailed_household_and_family_stat [分类]
23 : age [数值]
22 : state_of_previous_residence [分类]
21 : country_of_birth_self [分类]
21 : class_of_worker [分类]
20 : weeks_worked_in_year [数值]
20 : sex [分类]
15 : country_of_birth_father [分类]
12 : own_business_or_self_employed [分类]
11 : fill_inc_questionnaire_for_veteran's_admin [分类]
10 : num_persons_worked_for_employer [数值]
9 : tax_filer_stat [分类]
9 : full_or_part_time_employment_stat [分类]
8 : marital_stat [分类]
8 : country_of_birth_mother [分类]
6 : member_of_a_labor_union [分类]
5 : race [分类]
2 : hispanic_origin [分类]
1 : reason_for_unemployment [分类]
节点中深度 <= 2 的属性:
399 : detailed_occupation_recode [分类]
249 : detailed_industry_recode [分类]
170 : capital_gains [数值]
117 : dividends_from_stocks [数值]
116 : capital_losses [数值]
87 : education [分类]
59 : wage_per_hour [数值]
45 : detailed_household_and_family_stat [分类]
43 : country_of_birth_father [分类]
43 : age [数值]
40 : country_of_birth_self [分类]
38 : state_of_previous_residence [分类]
38 : class_of_worker [分类]
37 : sex [分类]
36 : weeks_worked_in_year [数值]
33 : country_of_birth_mother [分类]
28 : num_persons_worked_for_employer [数值]
26 : tax_filer_stat [分类]
14 : own_business_or_self_employed [分类]
14 : marital_stat [分类]
12 : full_or_part_time_employment_stat [分类]
12 : fill_inc_questionnaire_for_veteran's_admin [分类]
8 : member_of_a_labor_union [分类]
6 : race [分类]
6 : hispanic_origin [分类]
2 : major_occupation_code [分类]
2 : major_industry_code [分类]
1 : reason_for_unemployment [分类]
1 : migration_prev_res_in_sunbelt [分类]
1 : migration_code-change_in_reg [分类]
节点深度 <= 3 的属性:
785 : detailed_occupation_recode [分类]
668 : detailed_industry_recode [分类]
275 : capital_gains [数值]
220 : dividends_from_stocks [数值]
197 : capital_losses [数值]
178 : education [分类]
128 : country_of_birth_mother [分类]
116 : country_of_birth_father [分类]
114 : age [数值]
98 : wage_per_hour [数值]
95 : state_of_previous_residence [分类]
78 : detailed_household_and_family_stat [分类]
67 : class_of_worker [分类]
65 : sex [分类]
65 : country_of_birth_self [分类]
60 : weeks_worked_in_year [数值]
57 : tax_filer_stat [分类]
54 : num_persons_worked_for_employer [数值]
30 : own_business_or_self_employed [分类]
26 : marital_stat [分类]
16 : member_of_a_labor_union [分类]
15 : major_industry_code [分类]
15 : full_or_part_time_employment_stat [分类]
15 : fill_inc_questionnaire_for_veteran's_admin [分类]
9 : hispanic_origin [分类]
7 : race [分类]
7 : major_occupation_code [分类]
1 : veterans_benefits [分类]
1 : region_of_previous_residence [分类]
1 : reason_for_unemployment [分类]
1 : migration_prev_res_in_sunbelt [分类]
1 : migration_code-change_in_reg [分类]
1 : detailed_household_summary_in_household [分类]
1 : citizenship [分类]
节点深度 <= 5 的属性:
785 : detailed_occupation_recode [分类]
668 : detailed_industry_recode [分类]
275 : capital_gains [数值]
220 : dividends_from_stocks [数值]
197 : capital_losses [数值]
178 : education [分类]
128 : country_of_birth_mother [分类]
116 : country_of_birth_father [分类]
114 : age [数值]
98 : wage_per_hour [数值]
95 : state_of_previous_residence [分类]
78 : detailed_household_and_family_stat [分类]
67 : class_of_worker [分类]
65 : sex [分类]
65 : country_of_birth_self [分类]
60 : weeks_worked_in_year [数值]
57 : tax_filer_stat [分类]
54 : num_persons_worked_for_employer [数值]
30 : own_business_or_self_employed [分类]
26 : marital_stat [分类]
16 : member_of_a_labor_union [分类]
15 : major_industry_code [分类]
15 : full_or_part_time_employment_stat [分类]
15 : fill_inc_questionnaire_for_veteran's_admin [分类]
9 : hispanic_origin [分类]
7 : race [分类]
7 : major_occupation_code [分类]
1 : veterans_benefits [分类]
1 : region_of_previous_residence [分类]
1 : reason_for_unemployment [分类]
1 : migration_prev_res_in_sunbelt [分类]
1 : migration_code-change_in_reg [分类]
1 : detailed_household_summary_in_household [分类]
1 : citizenship [分类]
节点中的条件类型:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition
节点深度 <= 0 的条件类型:
137 : ContainsBitmapCondition
101 : HigherCondition
7 : ContainsCondition
节点深度 <= 1 的条件类型:
448 : ContainsBitmapCondition
278 : HigherCondition
9 : ContainsCondition
节点深度 <= 2 的条件类型:
1097 : ContainsBitmapCondition
569 : HigherCondition
17 : ContainsCondition
节点深度 <= 3 的条件类型:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition
节点深度 <= 5 的条件类型:
2418 : ContainsBitmapCondition
1018 : HigherCondition
31 : ContainsCondition

实验 2:决策森林与目标编码

目标编码是一种常见的预处理技术, 用于将分类特征转换为数值特征。 直接使用高基数的分类特征可能导致过拟合。 目标编码旨在用一个或多个数值替换每个分类特征值,这些数值表示其与目标标签的共现。

更确切地说,给定一个分类特征,此示例中的二进制目标编码器将生成三个新的数值特征:

  1. positive_frequency: 每个特征值与正目标标签一起出现的次数。
  2. negative_frequency: 每个特征值与负目标标签一起出现的次数。
  3. positive_probability: 给定特征值的情况下,目标标签为正的概率, 其计算公式为positive_frequency / (positive_frequency + negative_frequency + correction)correction 项的添加是为了使得对于稀有分类值的划分更稳定。 correction 的默认值为 1.0。

请注意,目标编码在无法自动处理的模型中是有效的。 学习稠密表示以处理类别特征,例如决策森林或核方法。如果使用神经网络模型,建议将类别特征编码为嵌入。

实现二进制目标编码器

为了简化,我们假设adaptcall方法的输入都是预期的数据类型和形状,因此没有添加验证逻辑。

建议将类别特征的vocabulary_size传递给BinaryTargetEncoding构造函数。如果未指定,它将在adapt()方法执行期间计算。

class BinaryTargetEncoding(layers.Layer):
    def __init__(self, vocabulary_size=None, correction=1.0, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.correction = correction

    def adapt(self, data):
        # data预计为一个整数numpy数组,形状为[示例数量, 2]。
        # 这包含数据集中给定特征的特征值和目标值。

        # 将数据转换为张量。
        data = tf.convert_to_tensor(data)
        # 分离特征值和目标值
        feature_values = tf.cast(data[:, 0], tf.dtypes.int32)
        target_values = tf.cast(data[:, 1], tf.dtypes.bool)

        # 计算未指定的vocabulary_size。
        if self.vocabulary_size is None:
            self.vocabulary_size = tf.unique(feature_values).y.shape[0]

        # 过滤目标标签为正的数据显示。
        positive_indices = tf.where(condition=target_values)
        postive_feature_values = tf.gather_nd(
            params=feature_values, indices=positive_indices
        )
        # 计算每个特征值与正目标标签发生的次数。
        positive_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(postive_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=postive_feature_values,
            num_segments=self.vocabulary_size,
        )

        # 过滤目标标签为负的数据。
        negative_indices = tf.where(condition=tf.math.logical_not(target_values))
        negative_feature_values = tf.gather_nd(
            params=feature_values, indices=negative_indices
        )
        # 计算每个特征值与负目标标签发生的次数。
        negative_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(negative_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=negative_feature_values,
            num_segments=self.vocabulary_size,
        )
        # 计算输入特征值的正概率。
        positive_probability = positive_frequency / (
            positive_frequency + negative_frequency + self.correction
        )
        # 连接计算得到的目标编码统计数据。
        target_encoding_statistics = tf.cast(
            tf.concat(
                [positive_frequency, negative_frequency, positive_probability], axis=1
            ),
            dtype=tf.dtypes.float32,
        )
        self.target_encoding_statistics = tf.constant(target_encoding_statistics)

    def call(self, inputs):
        # inputs预计为一个整数numpy数组,形状为[示例数量, 1]。
        # 这包括数据集中给定特征的特征值。

        # 如果目标编码统计数据未计算,则引发错误。
        if self.target_encoding_statistics == None:
            raise ValueError(
                f"您需要调用adapt方法来计算目标编码统计数据。"
            )

        # 将输入转换为张量。
        inputs = tf.convert_to_tensor(inputs)
        # 将输入强制转换为int64张量。
        inputs = tf.cast(inputs, tf.dtypes.int64)
        # 查找输入特征值的目标编码统计数据。
        target_encoding_statistics = tf.cast(
            tf.gather_nd(self.target_encoding_statistics, inputs),
            dtype=tf.dtypes.float32,
        )
        return target_encoding_statistics

让我们测试二进制目标编码器

data = tf.constant(
    [
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 1],
        [1, 0],
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 0],
    ]
)

binary_target_encoder = BinaryTargetEncoding()
binary_target_encoder.adapt(data)
print(binary_target_encoder([[0], [1], [2]]))
tf.Tensor(
[[6.         0.         0.85714287]
 [4.         3.         0.5       ]
 [1.         5.         0.14285715]], shape=(3, 3), dtype=float32)

创建模型输入

def create_model_inputs():
    inputs = {}

    for feature_name in NUMERIC_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.float32
        )

    for feature_name in CATEGORICAL_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.string
        )

    return inputs

实现带有目标编码的特征编码

def create_target_encoder():
    inputs = create_model_inputs()
    target_values = train_data[[TARGET_COLUMN_NAME]].to_numpy()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # 获取分类特征的词汇。
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # 创建一个查找表,将字符串值转换为整数索引。
            # 由于我们没有使用掩码标记,也不期待任何 OOV(超出词汇表)标记, 
            # 所以我们将 mask_token 设置为 None,将 num_oov_indices 设置为 0。
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # 将字符串输入值转换为整数索引。
            value_indices = lookup(inputs[feature_name])
            # 准备数据以适应目标编码。
            print("### 正在适应目标编码,特征名为:", feature_name)
            feature_values = train_data[[feature_name]].to_numpy().astype(str)
            feature_value_indices = lookup(feature_values)
            data = tf.concat([feature_value_indices, target_values], axis=1)
            feature_encoder = BinaryTargetEncoding()
            feature_encoder.adapt(data)
            # 将特征值索引转换为目标编码表示。
            encoded_feature = feature_encoder(tf.expand_dims(value_indices, -1))
        else:
            # 扩展数值输入特征的维度,按原样使用。
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # 将编码特征添加到列表中。
        encoded_features.append(encoded_feature)
    # 将所有编码特征连接起来。
    encoded_features = tf.concat(encoded_features, axis=1)
    # 创建并返回一个以编码特征为输出的 Keras 模型。
    return keras.Model(inputs=inputs, outputs=encoded_features)

使用预处理器创建梯度提升树模型

在这种情况下,我们将目标编码用作梯度提升树模型的预处理器, 并让模型推断输入特征的语义。

def create_gbt_with_preprocessor(preprocessor):

    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        preprocessing=preprocessor,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])

    return gbt_model

训练和评估模型

gbt_model = create_gbt_with_preprocessor(create_target_encoder())
run_experiment(gbt_model, train_data, test_data)
### 调整目标编码以便于:class_of_worker
### 调整目标编码以便于:detailed_industry_recode
### 调整目标编码以便于:detailed_occupation_recode
### 调整目标编码以便于:education
### 调整目标编码以便于:enroll_in_edu_inst_last_wk
### 调整目标编码以便于:marital_stat
### 调整目标编码以便于:major_industry_code
### 调整目标编码以便于:major_occupation_code
### 调整目标编码以便于:race
### 调整目标编码以便于:hispanic_origin
### 调整目标编码以便于:sex
### 调整目标编码以便于:member_of_a_labor_union
### 调整目标编码以便于:reason_for_unemployment
### 调整目标编码以便于:full_or_part_time_employment_stat
### 调整目标编码以便于:tax_filer_stat
### 调整目标编码以便于:region_of_previous_residence
### 调整目标编码以便于:state_of_previous_residence
### 调整目标编码以便于:detailed_household_and_family_stat
### 调整目标编码以便于:detailed_household_summary_in_household
### 调整目标编码以便于:migration_code-change_in_msa
### 调整目标编码以便于:migration_code-change_in_reg
### 调整目标编码以便于:migration_code-move_within_reg
### 调整目标编码以便于:live_in_this_house_1_year_ago
### 调整目标编码以便于:migration_prev_res_in_sunbelt
### 调整目标编码以便于:family_members_under_18
### 调整目标编码以便于:country_of_birth_father
### 调整目标编码以便于:country_of_birth_mother
### 调整目标编码以便于:country_of_birth_self
### 调整目标编码以便于:citizenship
### 调整目标编码以便于:own_business_or_self_employed
### 调整目标编码以便于:fill_inc_questionnaire_for_veteran's_admin
### 调整目标编码以便于:veterans_benefits
### 调整目标编码以便于:year
使用 /tmp/tmpj_0h78ld 作为临时训练目录
开始读取数据集
198/200 [============================>.] - ETA: 0s
数据集读取完毕,耗时0:00:06.793717
训练模型
模型训练完毕,耗时0:04:32.752691
编译模型
200/200 [==============================] - 280s 1s/步
测试准确率:95.81%

实验 3:使用训练嵌入的决策森林

在这种情况下,我们构建了一个编码器模型,该模型将类别特征编码为嵌入,其中给定类别特征的嵌入大小为其词汇量的平方根。

我们通过反向传播在一个简单的神经网络模型中训练这些嵌入。 嵌入编码器训练完成后,我们将其用作梯度提升树模型输入特征的预处理器。

请注意,嵌入和决策森林模型无法在一个阶段协同训练,因为决策森林模型无法通过反向传播进行训练。 相反,嵌入必须在初始阶段训练, 然后用作决策森林模型的静态输入。

使用嵌入实现特征编码

def create_embedding_encoder(size=None):
    inputs = create_model_inputs()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # 获取类别特征的词汇。
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # 创建查找表,将字符串值转换为整数索引。
            # 因为我们不使用掩码标记,也不期望任何词汇外
            # (oov) 标记,所以将 mask_token 设置为 None,num_oov_indices 设置为 0。
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # 将字符串输入值转换为整数索引。
            value_index = lookup(inputs[feature_name])
            # 创建一个具有指定维度的嵌入层
            vocabulary_size = len(vocabulary)
            embedding_size = int(math.sqrt(vocabulary_size))
            feature_encoder = layers.Embedding(
                input_dim=len(vocabulary), output_dim=embedding_size
            )
            # 将索引值转换为嵌入表示。
            encoded_feature = feature_encoder(value_index)
        else:
            # 扩展数值输入特征的维度并原样使用。
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # 将编码特征添加到列表中。
        encoded_features.append(encoded_feature)
    # 连接所有编码特征。
    encoded_features = layers.concatenate(encoded_features, axis=1)
    # 应用 dropout。
    encoded_features = layers.Dropout(rate=0.25)(encoded_features)
    # 执行非线性投影。
    encoded_features = layers.Dense(
        units=size if size else encoded_features.shape[-1], activation="gelu"
    )(encoded_features)
    # 创建并返回一个以编码特征为输出的 Keras 模型。
    return keras.Model(inputs=inputs, outputs=encoded_features)

构建一个神经网络模型以训练嵌入

def create_nn_model(encoder):
    inputs = create_model_inputs()
    embeddings = encoder(inputs)
    output = layers.Dense(units=1, activation="sigmoid")(embeddings)

    nn_model = keras.Model(inputs=inputs, outputs=output)
    nn_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy("accuracy")],
    )
    return nn_model


embedding_encoder = create_embedding_encoder(size=64)
run_experiment(
    create_nn_model(embedding_encoder),
    train_data,
    test_data,
    num_epochs=5,
    batch_size=256,
)
Epoch 1/5
200/200 [==============================] - 10s 27ms/step - loss: 8303.1455 - accuracy: 0.9193
Epoch 2/5
200/200 [==============================] - 5s 27ms/step - loss: 1019.4900 - accuracy: 0.9371
Epoch 3/5
200/200 [==============================] - 5s 27ms/step - loss: 612.2844 - accuracy: 0.9416
Epoch 4/5
200/200 [==============================] - 5s 27ms/step - loss: 858.9774 - accuracy: 0.9397
Epoch 5/5
200/200 [==============================] - 5s 26ms/step - loss: 842.3922 - accuracy: 0.9421
测试准确率: 95.0%

使用嵌入训练和评估梯度提升树模型

gbt_model = create_gbt_with_preprocessor(embedding_encoder)
run_experiment(gbt_model, train_data, test_data)
使用 /tmp/tmpao5o88p6 作为临时训练目录
开始读取数据集
199/200 [============================>.] - ETA: 0s
数据集读取时间为 0:00:06.722677
训练模型
模型训练时间为 0:05:18.350298
编译模型
200/200 [==============================] - 325s 2s/step
测试准确率: 95.82%

结束语

TensorFlow 决策森林提供了强大的模型,特别是对于结构化数据。 在我们的实验中,梯度提升树模型达到了 95.79% 的测试准确率。 在使用目标编码处理分类特征时,相同的模型实现了95.81%的测试准确率。 在预训练嵌入以供梯度提升树模型使用时,我们达到了95.82%的测试准确率。

决策森林可以与神经网络结合使用,方式有: 1) 使用神经网络学习输入数据的有用表示,然后使用决策森林进行监督学习任务,或 2) 创建决策森林和神经网络模型的集成。

请注意,TensorFlow决策森林目前不支持硬件加速器。 所有训练和推理均在CPU上完成。 此外,决策森林在其训练过程中需要一个有限的数据集,该数据集能够适应内存。然而,增加数据集的大小所带来的收益递减,并且决策森林算法在收敛方面所需的示例数量可能少于大型神经网络模型。

介绍
设置
结束语