ray.train.lightgbm.RayTrainReportCallback#

class ray.train.lightgbm.RayTrainReportCallback(metrics: str | List[str] | Dict[str, str] | None = None, filename: str = 'model.txt', frequency: int = 0, checkpoint_at_end: bool = True, results_postprocessing_fn: Callable[[Dict[str, float | List[float]]], Dict[str, float]] | None = None)[源代码]#

基类:object

创建一个回调函数,用于报告指标并检查模型。

参数:
  • metrics – 要报告的指标。如果这是一个列表,每个项目应为 LightGBM 报告的指标键,并且它将以相同名称报告给 Ray Train/Tune。这也可以是一个字典,形式为 {<要报告的键>: <lightgbm-metric-key>},用于重命名 LightGBM 的默认指标。

  • filename – 通过传递文件名来自定义保存的检查点文件类型。默认为“model.txt”。

  • frequency – 以迭代次数为单位,保存检查点的频率。默认为 0(训练期间不保存检查点)。

  • checkpoint_at_end – 是否在训练结束时保存检查点。

  • results_postprocessing_fn – 一个可选的可调用对象,它接收将被报告的指标字典(在它被展平之后)并返回一个修改后的字典。

示例

在运行许多独立的 xgboost 试验(试验内部没有数据并行)时,向 Ray Tune 报告检查点和指标。

import lightgbm

from ray.train.lightgbm import RayTrainReportCallback

config = {
    # ...
    "metric": ["binary_logloss", "binary_error"],
}

# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
    ...,
    callbacks=[
        RayTrainReportCallback(
            metrics={"loss": "eval-binary_logloss"}, frequency=1
        )
    ],
)

从该回调报告的检查点加载模型。

from ray.train.lightgbm import RayTrainReportCallback

# Get a `Checkpoint` object that is saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。

方法

get_model

检索此回调报告的检查点中存储的模型。

属性

CHECKPOINT_NAME