ray.train.xgboost.RayTrainReportCallback#
- class ray.train.xgboost.RayTrainReportCallback(*args: Any, **kwargs: Any)[源代码]#
基类:
TuneCallback
XGBoost 回调函数用于保存检查点和报告指标。
- 参数:
metrics – 要报告的指标。如果这是一个列表,每个项目描述了报告给 XGBoost 的指标键,并且它将以相同名称报告。这也可以是一个 {<要报告的键>: <xgboost-metric-key>} 的字典,用于重命名 XGBoost 默认指标。
filename – 通过传递文件名来自定义保存的检查点文件类型。默认为“model.ubj”。
frequency – 以迭代次数为单位,保存检查点的频率。默认为 0(训练期间不保存检查点)。
checkpoint_at_end – 是否在训练结束时保存检查点。
results_postprocessing_fn – 一个可选的可调用对象,它接收将要报告的指标字典(在它被展平之后)并返回一个修改后的字典。例如,当使用
xgboost.cv
时,这可以用于平均跨CV折叠的结果。
示例
在运行许多独立的 xgboost 试验(试验内部没有数据并行)时,向 Ray Tune 报告检查点和指标。
import xgboost from ray.tune import Tuner from ray.train.xgboost import RayTrainReportCallback def train_fn(config): # Report log loss to Ray Tune after each validation epoch. bst = xgboost.train( ..., callbacks=[ RayTrainReportCallback( metrics={"loss": "eval-logloss"}, frequency=1 ) ], ) tuner = Tuner(train_fn) results = tuner.fit()
从该回调报告的检查点加载模型。
from ray.train.xgboost import RayTrainReportCallback # Get a `Checkpoint` object that is saved by the callback during training. result = trainer.fit() booster = RayTrainReportCallback.get_model(result.checkpoint)
PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。
方法
检索此回调报告的检查点中存储的模型。
属性