.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/linear_model/plot_sgd_early_stopping.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_linear_model_plot_sgd_early_stopping.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_linear_model_plot_sgd_early_stopping.py:

=============================================
Early stopping of Stochastic Gradient Descent
=============================================

Stochastic Gradient Descent is an optimization technique that minimizes a loss
function in a stochastic fashion, performing a gradient descent step sample by
sample. In particular, it is a very efficient way to fit linear models.

As a stochastic method, the loss function does not necessarily decrease at each
iteration, and convergence is only guaranteed in expectation. For this reason,
monitoring the convergence of the loss function can be difficult.

Another approach is to monitor convergence on a validation score. In this case,
the input data is split into a training set and a validation set. The model is
then fitted on the training set, and the stopping criterion is based on the
prediction score computed on the validation set. This lets us find the smallest
number of iterations that is sufficient to build a model that generalizes well
to unseen data, and it reduces the chance of over-fitting the training data.

This early stopping strategy is activated if ``early_stopping=True``; otherwise
the stopping criterion only uses the training loss on the entire input data. To
better control the early stopping strategy, we can specify a parameter
``validation_fraction``, which sets the fraction of the input dataset held out
to compute the validation score. The optimization continues until the
validation score has failed to improve by at least ``tol`` during the last
``n_iter_no_change`` iterations. The actual number of iterations is available
in the attribute ``n_iter_``.

This example illustrates how early stopping can be used in the
:class:`~sklearn.linear_model.SGDClassifier` model to achieve almost the same
accuracy as a model built without early stopping, which can significantly
reduce training time. Note that scores differ between the stopping criteria
even from early iterations, because some of the training data is held out by
the validation stopping criterion.

.. GENERATED FROM PYTHON SOURCE LINES 17-133

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/linear_model/images/sphx_glr_plot_sgd_early_stopping_001.png
         :alt: Train score, Test score
         :srcset: /auto_examples/linear_model/images/sphx_glr_plot_sgd_early_stopping_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/linear_model/images/sphx_glr_plot_sgd_early_stopping_002.png
         :alt: n_iter_, Fit time (sec)
         :srcset: /auto_examples/linear_model/images/sphx_glr_plot_sgd_early_stopping_002.png
         :class: sphx-glr-multi-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    No stopping criterion: .................................................
    Training loss: .................................................
    Validation score: .................................................

|
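Before the full example source below, the snippet here is a minimal sketch of the
validation-based stopping criterion described above. The synthetic dataset and the
exact parameter values are illustrative only and are not part of this example's
benchmark.

.. code-block:: Python

    # Minimal sketch of validation-based early stopping (illustrative values only;
    # the benchmark below uses MNIST and compares several stopping criteria).
    from sklearn.datasets import make_classification
    from sklearn.linear_model import SGDClassifier

    X, y = make_classification(n_samples=2000, random_state=0)

    clf = SGDClassifier(
        early_stopping=True,      # hold out part of the training data for scoring
        validation_fraction=0.2,  # fraction of X kept aside as the validation set
        n_iter_no_change=3,       # patience, in epochs
        tol=1e-3,                 # minimum score improvement to keep going
        max_iter=1000,
        random_state=0,
    )
    clf.fit(X, y)

    # Number of epochs actually run before the criterion triggered.
    print(clf.n_iter_)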
.. code-block:: Python

    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    import sys
    import time

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    from sklearn import linear_model
    from sklearn.datasets import fetch_openml
    from sklearn.exceptions import ConvergenceWarning
    from sklearn.model_selection import train_test_split
    from sklearn.utils import shuffle
    from sklearn.utils._testing import ignore_warnings


    def load_mnist(n_samples=None, class_0="0", class_1="8"):
        """Load MNIST, select two classes, shuffle and return only n_samples."""
        # Load data from http://openml.org/d/554
        mnist = fetch_openml("mnist_784", version=1, as_frame=False)

        # take only two classes for binary classification
        mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)

        X, y = shuffle(mnist.data[mask], mnist.target[mask], random_state=42)
        if n_samples is not None:
            X, y = X[:n_samples], y[:n_samples]
        return X, y


    @ignore_warnings(category=ConvergenceWarning)
    def fit_and_score(estimator, max_iter, X_train, X_test, y_train, y_test):
        """Fit the estimator on the train set and score it on both sets."""
        estimator.set_params(max_iter=max_iter)
        estimator.set_params(random_state=0)

        start = time.time()
        estimator.fit(X_train, y_train)

        fit_time = time.time() - start
        n_iter = estimator.n_iter_
        train_score = estimator.score(X_train, y_train)
        test_score = estimator.score(X_test, y_test)

        return fit_time, n_iter, train_score, test_score


    # Define the estimators to compare
    estimator_dict = {
        "No stopping criterion": linear_model.SGDClassifier(n_iter_no_change=3),
        "Training loss": linear_model.SGDClassifier(
            early_stopping=False, n_iter_no_change=3, tol=0.1
        ),
        "Validation score": linear_model.SGDClassifier(
            early_stopping=True, n_iter_no_change=3, tol=0.0001, validation_fraction=0.2
        ),
    }

    # Load the dataset
    X, y = load_mnist(n_samples=10000)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

    results = []
    for estimator_name, estimator in estimator_dict.items():
        print(estimator_name + ": ", end="")
        for max_iter in range(1, 50):
            print(".", end="")
            sys.stdout.flush()

            fit_time, n_iter, train_score, test_score = fit_and_score(
                estimator, max_iter, X_train, X_test, y_train, y_test
            )

            results.append(
                (estimator_name, max_iter, fit_time, n_iter, train_score, test_score)
            )
        print("")

    # Transform the results into a pandas dataframe for easy plotting
    columns = [
        "Stopping criterion",
        "max_iter",
        "Fit time (sec)",
        "n_iter_",
        "Train score",
        "Test score",
    ]
    results_df = pd.DataFrame(results, columns=columns)

    # Define what to plot
    lines = "Stopping criterion"
    x_axis = "max_iter"
    styles = ["-.", "--", "-"]

    # First plot: train and test scores
    fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(12, 4))
    for ax, y_axis in zip(axes, ["Train score", "Test score"]):
        for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):
            group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)
        ax.set_title(y_axis)
        ax.legend(title=lines)
    fig.tight_layout()

    # Second plot: number of iterations and fit time
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
    for ax, y_axis in zip(axes, ["n_iter_", "Fit time (sec)"]):
        for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):
            group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)
        ax.set_title(y_axis)
        ax.legend(title=lines)
    fig.tight_layout()

    plt.show()

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 40.061 seconds)


.. _sphx_glr_download_auto_examples_linear_model_plot_sgd_early_stopping.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge
      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/linear_model/plot_sgd_early_stopping.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sgd_early_stopping.ipynb <plot_sgd_early_stopping.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sgd_early_stopping.py <plot_sgd_early_stopping.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_sgd_early_stopping.zip <plot_sgd_early_stopping.zip>`

.. include:: plot_sgd_early_stopping.recommendations

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_