.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ensemble/plot_gradient_boosting_early_stopping.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_ensemble_plot_gradient_boosting_early_stopping.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ensemble_plot_gradient_boosting_early_stopping.py:

===================================
Early stopping in Gradient Boosting
===================================

Gradient Boosting is an ensemble technique that combines multiple weak
learners, typically decision trees, to create a robust and powerful
predictive model. It does so in an iterative fashion, where each new stage
(tree) corrects the errors of the previous ones.

Early stopping is a technique in Gradient Boosting that allows us to find
the optimal number of iterations required to build a model that generalizes
well to unseen data and avoids overfitting. The concept is simple: we set
aside a portion of our dataset as a validation set (specified using
`validation_fraction`) to assess the model's performance during training.
As the model is iteratively built with additional stages (trees), its
performance on the validation set is monitored as a function of the number
of steps.

Early stopping becomes effective when the model's performance on the
validation set plateaus or worsens (within deviations specified by `tol`)
for a certain number of consecutive stages (specified by
`n_iter_no_change`). This signals that the model has reached a point where
further iterations may lead to overfitting, and it is time to stop training.

The number of estimators (trees) in the final model, when early stopping is
applied, can be accessed using the `n_estimators_` attribute. Overall, early
stopping is a valuable tool to strike a balance between model performance
and efficiency in gradient boosting.

License: BSD 3 clause

.. GENERATED FROM PYTHON SOURCE LINES 19-22

Data Preparation
----------------
First, we load and prepare the California Housing Prices dataset for
training and evaluation. We subset the dataset and split it into training
and validation sets.

.. GENERATED FROM PYTHON SOURCE LINES 22-38

.. code-block:: Python

    import time

    import matplotlib.pyplot as plt

    from sklearn.datasets import fetch_california_housing
    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import train_test_split

    data = fetch_california_housing()
    X, y = data.data[:600], data.target[:600]
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

.. GENERATED FROM PYTHON SOURCE LINES 39-42

Model Training and Comparison
-----------------------------
Two :class:`~sklearn.ensemble.GradientBoostingRegressor` models are trained:
one with early stopping and another without it. The purpose is to compare
their performance. The training time and the `n_estimators_` used by both
models are also recorded.

.. GENERATED FROM PYTHON SOURCE LINES 42-63

.. code-block:: Python

    params = dict(n_estimators=1000, max_depth=5, learning_rate=0.1, random_state=42)

    # Reference model: always fits all 1000 boosting stages.
    gbm_full = GradientBoostingRegressor(**params)

    # Early-stopping model: sets aside 10% of the training data as an
    # internal validation set and stops after 10 stages without improvement.
    gbm_early_stopping = GradientBoostingRegressor(
        **params,
        validation_fraction=0.1,
        n_iter_no_change=10,
    )

    start_time = time.time()
    gbm_full.fit(X_train, y_train)
    training_time_full = time.time() - start_time
    n_estimators_full = gbm_full.n_estimators_

    start_time = time.time()
    gbm_early_stopping.fit(X_train, y_train)
    training_time_early_stopping = time.time() - start_time
    estimators_early_stopping = gbm_early_stopping.n_estimators_

.. GENERATED FROM PYTHON SOURCE LINES 64-67

Error Calculation
-----------------
The code calculates the :func:`~sklearn.metrics.mean_squared_error` on the
training and validation datasets for the models trained in the previous
section. It computes the errors for each boosting iteration. The purpose is
to assess the performance and convergence of the models.

.. GENERATED FROM PYTHON SOURCE LINES 67-93

.. code-block:: Python

    train_errors_without = []
    val_errors_without = []
    train_errors_with = []
    val_errors_with = []

    # Errors of the model trained without early stopping, one entry per stage.
    for i, (train_pred, val_pred) in enumerate(
        zip(
            gbm_full.staged_predict(X_train),
            gbm_full.staged_predict(X_val),
        )
    ):
        train_errors_without.append(mean_squared_error(y_train, train_pred))
        val_errors_without.append(mean_squared_error(y_val, val_pred))

    # Errors of the model trained with early stopping, one entry per stage.
    for i, (train_pred, val_pred) in enumerate(
        zip(
            gbm_early_stopping.staged_predict(X_train),
            gbm_early_stopping.staged_predict(X_val),
        )
    ):
        train_errors_with.append(mean_squared_error(y_train, train_pred))
        val_errors_with.append(mean_squared_error(y_val, val_pred))
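
The following short sketch is an addition to the generated example (not part
of the original script): it reuses the error lists computed above to report
the boosting iteration with the lowest validation MSE for each model, next to
the number of estimators `gbm_early_stopping` actually fitted. It only
assumes NumPy, which scikit-learn already depends on.

.. code-block:: Python

    import numpy as np

    # np.argmin returns the zero-based index of the smallest validation MSE;
    # adding 1 converts it into a boosting-iteration count.
    best_iter_full = int(np.argmin(val_errors_without)) + 1
    best_iter_early = int(np.argmin(val_errors_with)) + 1

    print(f"gbm_full: lowest validation MSE at iteration {best_iter_full}")
    print(f"gbm_early_stopping: lowest validation MSE at iteration {best_iter_early}")
    print(f"gbm_early_stopping fitted {gbm_early_stopping.n_estimators_} estimators")
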

.. GENERATED FROM PYTHON SOURCE LINES 94-102

Visualize Comparison
--------------------
The comparison figure includes three subplots:

1. Plotting training errors of both models over boosting iterations.
2. Plotting validation errors of both models over boosting iterations.
3. Creating a bar chart to compare the training times and the number of
   estimators used by the models with and without early stopping.

.. GENERATED FROM PYTHON SOURCE LINES 102-139

.. code-block:: Python

    fig, axes = plt.subplots(ncols=3, figsize=(12, 4))

    # Training error over boosting iterations.
    axes[0].plot(train_errors_without, label="gbm_full")
    axes[0].plot(train_errors_with, label="gbm_early_stopping")
    axes[0].set_xlabel("Boosting Iterations")
    axes[0].set_ylabel("MSE (Training)")
    axes[0].set_yscale("log")
    axes[0].legend()
    axes[0].set_title("Training Error")

    # Validation error over boosting iterations.
    axes[1].plot(val_errors_without, label="gbm_full")
    axes[1].plot(val_errors_with, label="gbm_early_stopping")
    axes[1].set_xlabel("Boosting Iterations")
    axes[1].set_ylabel("MSE (Validation)")
    axes[1].set_yscale("log")
    axes[1].legend()
    axes[1].set_title("Validation Error")

    # Training time and number of estimators actually fitted.
    training_times = [training_time_full, training_time_early_stopping]
    labels = ["gbm_full", "gbm_early_stopping"]
    bars = axes[2].bar(labels, training_times)
    axes[2].set_ylabel("Training Time (s)")

    for bar, n_estimators in zip(bars, [n_estimators_full, estimators_early_stopping]):
        height = bar.get_height()
        axes[2].text(
            bar.get_x() + bar.get_width() / 2,
            height + 0.001,
            f"Estimators: {n_estimators}",
            ha="center",
            va="bottom",
        )

    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_early_stopping_001.png
   :alt: Training Error, Validation Error
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_gradient_boosting_early_stopping_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 140-141

The difference in training error between `gbm_full` and `gbm_early_stopping`
stems from the fact that `gbm_early_stopping` sets aside a
`validation_fraction` of the training data as an internal validation set.
Early stopping is decided based on this internal validation score.

.. GENERATED FROM PYTHON SOURCE LINES 144-150

Summary
-------
In our example with the :class:`~sklearn.ensemble.GradientBoostingRegressor`
model on the California Housing Prices dataset, we have demonstrated the
practical benefits of early stopping:

- **Preventing Overfitting:** We showed how the validation error stabilizes
  or starts to increase after a certain point, indicating that the model
  generalizes better to unseen data. This is achieved by stopping the
  training process before overfitting occurs.
- **Improving Training Efficiency:** We compared training times between
  models with and without early stopping. The model with early stopping
  achieved comparable accuracy while requiring significantly fewer
  estimators, resulting in faster training.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.398 seconds)


.. _sphx_glr_download_auto_examples_ensemble_plot_gradient_boosting_early_stopping.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/ensemble/plot_gradient_boosting_early_stopping.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gradient_boosting_early_stopping.ipynb <plot_gradient_boosting_early_stopping.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gradient_boosting_early_stopping.py <plot_gradient_boosting_early_stopping.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_gradient_boosting_early_stopping.zip <plot_gradient_boosting_early_stopping.zip>`

.. include:: plot_gradient_boosting_early_stopping.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_