ray.train.xgboost.RayTrainReportCallback#

class ray.train.xgboost.RayTrainReportCallback(*args: Any, **kwargs: Any)[源代码]#

基类:TuneCallback

XGBoost 回调函数用于保存检查点和报告指标。

参数:
  • metrics – 要报告的指标。如果这是一个列表,每个项目描述了报告给 XGBoost 的指标键,并且它将以相同名称报告。这也可以是一个 {<要报告的键>: <xgboost-metric-key>} 的字典,用于重命名 XGBoost 默认指标。

  • filename – 通过传递文件名来自定义保存的检查点文件类型。默认为“model.ubj”。

  • frequency – 以迭代次数为单位,保存检查点的频率。默认为 0(训练期间不保存检查点)。

  • checkpoint_at_end – 是否在训练结束时保存检查点。

  • results_postprocessing_fn – 一个可选的可调用对象,它接收将要报告的指标字典(在它被展平之后)并返回一个修改后的字典。例如,当使用 xgboost.cv 时,这可以用于平均跨CV折叠的结果。

示例

在运行许多独立的 xgboost 试验(试验内部没有数据并行)时,向 Ray Tune 报告检查点和指标。

import xgboost

from ray.tune import Tuner
from ray.train.xgboost import RayTrainReportCallback

def train_fn(config):
    # Report log loss to Ray Tune after each validation epoch.
    bst = xgboost.train(
        ...,
        callbacks=[
            RayTrainReportCallback(
                metrics={"loss": "eval-logloss"}, frequency=1
            )
        ],
    )

tuner = Tuner(train_fn)
results = tuner.fit()

从该回调报告的检查点加载模型。

from ray.train.xgboost import RayTrainReportCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。

方法

get_model

检索此回调报告的检查点中存储的模型。

属性

CHECKPOINT_NAME