回调函数
本文档提供了对 XGBoost Python 包中使用的 回调 API 的基本介绍。在 XGBoost 1.3 中,为 Python 包设计了一个新的回调接口,该接口提供了设计各种训练扩展的灵活性。此外,XGBoost 还提供了许多预定义的回调,用于支持提前停止、检查点等功能。
使用内置回调
默认情况下,XGBoost 的训练方法有诸如 early_stopping_rounds
和 verbose
/verbose_eval
这样的参数,当指定这些参数时,训练过程会在内部定义相应的回调。例如,当指定 early_stopping_rounds
时,EarlyStopping
回调会在迭代循环内部被调用。你也可以直接将此回调函数传递给 XGBoost:
D_train = xgb.DMatrix(X_train, y_train)
D_valid = xgb.DMatrix(X_valid, y_valid)
# Define a custom evaluation metric used for early stopping.
def eval_error_metric(predt, dtrain: xgb.DMatrix):
label = dtrain.get_label()
r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]
le = predt <= 0.5
r[le] = label[le]
return 'CustomErr', np.sum(r)
# Specify which dataset and which metric should be used for early stopping.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
metric_name='CustomErr',
data_name='Valid')
booster = xgb.train(
{'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
feval=eval_error_metric,
num_boost_round=1000,
callbacks=[early_stop],
verbose_eval=False)
dump = booster.get_dump(dump_format='json')
assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
定义你自己的回调
XGBoost 提供了一个回调接口类:TrainingCallback
,用户定义的回调应继承此类并重写相应的方法。在 使用和定义回调函数的演示 中有一个工作示例。