备注
前往结尾下载完整示例代码。
使用回调与Dask的示例
import operator
from typing import Any

import numpy as np
import xgboost as xgb
import xgboost.dask as dxgb
from dask.distributed import Client, LocalCluster
from dask_ml.datasets import make_regression
from dask_ml.model_selection import train_test_split
from xgboost.dask import DaskDMatrix
def probability_for_going_backward(epoch: int) -> float:
    """Return the probability of tolerating a worse metric at ``epoch``.

    Computed as ``0.999 / (1 + 0.05 * ln(1 + epoch))``: exactly 0.999 at
    epoch 0 and decreasing slowly as ``epoch`` grows.
    """
    # Logarithmic damping term; equals 1.0 at epoch 0.
    damping = 1.0 + 0.05 * np.log(1.0 + epoch)
    return 0.999 / damping
# All callback functions must inherit from TrainingCallback
class CustomEarlyStopping(xgb.callback.TrainingCallback):
    """Early-stopping callback whose stop decision is made stochastically.

    Whenever the monitored metric fails to improve on the previous round,
    training continues with probability
    ``probability_for_going_backward(epoch)`` and stops otherwise. That
    probability is 0.999 at the first round and is adjusted downward as
    boosting progresses.

    Parameters
    ----------
    validation_set
        Name of the evaluation set (first-level key of ``evals_log``) to watch.
    target_metric
        Name of the metric (second-level key of ``evals_log``) to watch.
    maximize
        ``True`` if larger metric values are better, ``False`` if smaller
        values are better.
    seed
        Seed of the random generator used to draw the continue/stop decision.
    """

    def __init__(
        self, *, validation_set: str, target_metric: str, maximize: bool, seed: int
    ) -> None:
        super().__init__()
        self.validation_set = validation_set
        self.target_metric = target_metric
        self.maximize = maximize
        self.seed = seed
        self.rng = np.random.default_rng(seed=seed)
        # better(current, previous) is True when `current` is a strict
        # improvement. Stdlib operator functions are used instead of lambdas
        # so the attribute can be pickled with the standard pickle module.
        self.better = operator.gt if maximize else operator.lt

    def after_iteration(
        self, model: Any, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
    ) -> bool:
        """Decide after a boosting round whether training should stop.

        Parameters
        ----------
        model
            The model being trained; not used by this callback.
        epoch
            Index of the boosting round that just finished.
        evals_log
            Metric history, indexed as ``evals_log[set_name][metric_name]``.

        Returns
        -------
        bool
            ``True`` to stop training, ``False`` to continue.

        Raises
        ------
        KeyError
            If ``validation_set`` / ``target_metric`` is missing from
            ``evals_log``.
        """
        try:
            metric_history = evals_log[self.validation_set][self.target_metric]
        except KeyError as err:
            raise KeyError(
                f"No history for metric {self.target_metric!r} on validation set "
                f"{self.validation_set!r}; check the `evals` names and `eval_metric`."
            ) from err

        # Nothing to compare against yet, or the metric improved: keep going.
        if len(metric_history) < 2 or self.better(
            metric_history[-1], metric_history[-2]
        ):
            return False

        # The metric did not improve: draw 1 ("tolerate it") with probability p
        # and 0 ("stop") with probability 1 - p.
        p = probability_for_going_backward(epoch)
        go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype(
            np.bool_
        )[0]
        print(
            "The validation metric went into the wrong direction. "
            + f"Stopping training with probability {1 - p}..."
        )
        # `not` on a NumPy bool yields a plain Python bool.
        return not go_backward
def main(client: Client) -> None:
    """Train a regression model on synthetic data with stochastic early stopping.

    Generates a synthetic regression dataset, splits it into train and test
    parts, and runs distributed training through ``client`` with
    ``CustomEarlyStopping`` watching the ``rmse`` metric of the ``"test"``
    evaluation set.

    Parameters
    ----------
    client
        Dask client used to build the ``DaskDMatrix`` objects and to train.
    """
    n_samples = 100000
    n_features = 100
    X, y = make_regression(
        n_samples=n_samples, n_features=n_features, chunks=200, random_state=0
    )
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    dtrain = DaskDMatrix(client, X_train, y_train)
    dtest = DaskDMatrix(client, X_test, y_test)

    params = {
        "verbosity": 1,
        "tree_method": "hist",
        "objective": "reg:squarederror",
        "eval_metric": "rmse",
        "max_depth": 6,
        "learning_rate": 1.0,
    }
    # The names passed here must match the names given in `evals` below and
    # the "eval_metric" entry of `params`.
    early_stopping = CustomEarlyStopping(
        validation_set="test", target_metric="rmse", maximize=False, seed=0
    )

    # The return value is not needed by this example, so it is not bound.
    dxgb.train(
        client,
        params,
        dtrain,
        num_boost_round=1000,
        evals=[(dtrain, "train"), (dtest, "test")],
        callbacks=[early_stopping],
    )
if __name__ == "__main__":
    # Run on a local cluster of 4 single-threaded workers; substitute another
    # cluster type here to scale out.
    with LocalCluster(n_workers=4, threads_per_worker=1) as cluster, \
            Client(cluster) as client:
        main(client)