Note
Go to the end to download the full example code.
Demo for training continuation
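This demo shows how to continue training from an existing model: a booster trained for some number of rounds is checkpointed (here via pickle or the JSON model format) and then passed back to fit() through the xgb_model parameter, so that later sessions add more boosted rounds, with or without early stopping.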
import os
import pickle
import tempfile

from sklearn.datasets import load_breast_cancer

import xgboost
def training_continuation(tmpdir: str, use_pickle: bool) -> None:
    """Basic training continuation."""
    # Train 128 iterations in one session.
    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(n_estimators=128, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)])
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())

    # Train 128 iterations in two sessions: the first one runs for 32 iterations
    # and the second one runs for the remaining 96 iterations.
    clf = xgboost.XGBClassifier(n_estimators=32, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.get_booster().num_boosted_rounds() == 32

    # Load the model back; this could be a checkpoint.
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-32.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-32.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier()
        loaded.load_model(path)

    # Continue training from the loaded model for the remaining 96 rounds.
    clf = xgboost.XGBClassifier(n_estimators=128 - 32, eval_metric="logloss")
    clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    assert clf.get_booster().num_boosted_rounds() == 128
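# For reference (not part of the original demo): the same continuation pattern,
# sketched with the native ``xgboost.train`` interface, which likewise accepts a
# previous model through its ``xgb_model`` argument.
def native_training_continuation(tmpdir: str) -> None:
    """Sketch: training continuation with the native ``xgboost.train`` API."""
    X, y = load_breast_cancer(return_X_y=True)
    Xy = xgboost.DMatrix(X, label=y)
    params = {"objective": "binary:logistic", "eval_metric": "logloss"}

    # First session: 32 rounds, then checkpoint to the JSON format.
    booster = xgboost.train(params, Xy, num_boost_round=32)
    path = os.path.join(tmpdir, "native-first-32.json")
    booster.save_model(path)

    # Second session: resume from the checkpoint for the remaining 96 rounds.
    loaded = xgboost.Booster(model_file=path)
    booster = xgboost.train(params, Xy, num_boost_round=96, xgb_model=loaded)
    assert booster.num_boosted_rounds() == 128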
def training_continuation_early_stop(tmpdir: str, use_pickle: bool) -> None:
    """Training continuation with early stopping."""
    early_stopping_rounds = 5
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    n_estimators = 512

    X, y = load_breast_cancer(return_X_y=True)
    clf = xgboost.XGBClassifier(
        n_estimators=n_estimators, eval_metric="logloss", callbacks=[early_stop]
    )
    clf.fit(X, y, eval_set=[(X, y)])
    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    best = clf.best_iteration

    # Train 512 iterations in two sessions: the first one runs for 128 iterations
    # and the second one runs until early stopping kicks in.
    clf = xgboost.XGBClassifier(
        n_estimators=128, eval_metric="logloss", callbacks=[early_stop]
    )
    # Reinitialize the early stopping callback: callback objects keep internal
    # state and cannot be reused across training sessions.
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf.set_params(callbacks=[early_stop])
    clf.fit(X, y, eval_set=[(X, y)])
    assert clf.get_booster().num_boosted_rounds() == 128
    # Load the model back; this could be a checkpoint.
    if use_pickle:
        path = os.path.join(tmpdir, "model-first-128.pkl")
        with open(path, "wb") as fd:
            pickle.dump(clf, fd)
        with open(path, "rb") as fd:
            loaded = pickle.load(fd)
    else:
        path = os.path.join(tmpdir, "model-first-128.json")
        clf.save_model(path)
        loaded = xgboost.XGBClassifier()
        loaded.load_model(path)

    # A fresh callback is needed again for the second session.
    early_stop = xgboost.callback.EarlyStopping(
        rounds=early_stopping_rounds, save_best=True
    )
    clf = xgboost.XGBClassifier(
        n_estimators=n_estimators - 128, eval_metric="logloss", callbacks=[early_stop]
    )
    clf.fit(
        X,
        y,
        eval_set=[(X, y)],
        xgb_model=loaded,
    )

    print("Total boosted rounds:", clf.get_booster().num_boosted_rounds())
    assert clf.best_iteration == best
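# Note (not in the original demo): recent XGBoost versions also accept an
# ``early_stopping_rounds`` argument directly in the estimator constructor,
# e.g. ``xgboost.XGBClassifier(n_estimators=512, eval_metric="logloss",
# early_stopping_rounds=5)``, which avoids managing callback objects across
# sessions.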
if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        training_continuation_early_stop(tmpdir, False)
        training_continuation_early_stop(tmpdir, True)
        training_continuation(tmpdir, True)
        training_continuation(tmpdir, False)