API参考
本页面记录了YDF的Python API。用户还可以使用C++和CLI API训练模型。
学习器
学习器用于训练模型,并且可以进行交叉验证。
- GradientBoostedTreesLearner
- RandomForestLearner
- CartLearner
- DecisionTreeLearner: 别名为CartLearner。
- DistributedGradientBoostedTreesLearner
- IsolationForestLearner
所有学习器都继承自GenericLearner。
模型
模型用于进行预测并可以进行评估。
注意: 模型(例如,GradientBoostedTreesModel
)不包含训练功能。要训练模型,您需要创建一个学习器(例如,GradientBoostedTreesLearner
)。训练超参数是学习器类的构造函数参数。
所有模型都继承自GenericModel。
调优器
调优器通过重复训练和评估来找到最佳的超参数集。
- RandomSearchTuner
- VizierTuner(目前仅限Google员工使用)
其他
- verbose: 控制日志记录的详细程度。
- load_model: 从磁盘加载模型。
- Feature: 输入特征特定的超参数,例如语义、约束。
- Column:
Feature
的别名。 - Task: 指定模型解决的任务,例如分类。
- Semantic: 输入特征的解释方式,例如数值、分类。
- start_worker: 启动分布式训练的工作节点。
- strict: 显示更多日志。
实用工具
- ydf.util.read_tf_record: 将TF Record数据集读入内存。
- ydf.util.write_tf_record: 将TF Record数据集从内存写入磁盘。
高级实用工具
- ModelIOOptions: 将模型保存到磁盘的选项。
- create_vertical_dataset: 将数据集加载到内存中。
- ModelMetadata: 关于模型的元数据,例如训练日期、唯一标识符。
- from_tensorflow_decision_forests: 从磁盘加载TensorFlow决策森林模型。
- from_sklearn: 将scikit-learn模型转换为YDF模型。
- NodeFormat: 用于序列化树节点的格式。
自定义损失
- RegressionLoss: 回归任务的自定义损失。
- BinaryClassificationLoss: 二分类任务的自定义损失。
- MultiClassificationLoss: 多分类任务的自定义损失。
- Activation: 自定义损失的激活(即链接)函数的集合。
树
ydf.tree.*
类提供了对树结构、叶子和值的编程读写访问。
- tree.Tree: 由
model.get_tree(...)
和model.set_tree(...)
返回和使用的决策树。
条件
- tree.AbstractCondition: 基础条件类。
- tree.NumericalHigherThanCondition: 形式为
attribute >= threshold
的条件。 - tree.CategoricalIsInCondition: 形式为
attribute in mask
的条件。 - tree.CategoricalSetContainsCondition: 形式为
attribute intersect mask != empty
的条件。 - tree.DiscretizedNumericalHigherThanCondition: 形式为
attribute >= bounds[threshold]
的条件。 - tree.IsMissingInCondition: 形式为
attribute is missing
的条件。 - tree.IsTrueCondition: 形式为
attribute is true
的条件。 - tree.NumericalSparseObliqueCondition: 形式为
sum(attributes[i] * weights[i]) >= threshold
的条件。
节点
- tree.AbstractNode: 基础节点类。
- tree.Leaf: 包含一个值的叶子节点。
- tree.NonLeaf: 包含条件的非叶子节点。
值
- tree.AbstractValue: 基础值类。
- tree.ProbabilityValue: 概率分布值。
- tree.Leaf: 回归树的回归值。
- tree.UpliftValue: 分类或回归提升树的提升值。