.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/linear_model/plot_sgd_comparison.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_linear_model_plot_sgd_comparison.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_linear_model_plot_sgd_comparison.py:


==================================
Comparing various online solvers
==================================

An example showing how different online solvers perform
on the hand-written digits dataset.

.. GENERATED FROM PYTHON SOURCE LINES 7-70

.. image-sg:: /auto_examples/linear_model/images/sphx_glr_plot_sgd_comparison_001.png
   :alt: plot sgd comparison
   :srcset: /auto_examples/linear_model/images/sphx_glr_plot_sgd_comparison_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    training SGD
    training ASGD
    training Perceptron
    training Passive-Aggressive I
    training Passive-Aggressive II
    training SAG


|

.. code-block:: Python

    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn import datasets
    from sklearn.linear_model import (
        LogisticRegression,
        PassiveAggressiveClassifier,
        Perceptron,
        SGDClassifier,
    )
    from sklearn.model_selection import train_test_split

    # Fractions of the data held out as the test set.
    heldout = [0.95, 0.90, 0.75, 0.50, 0.01]
    # Number of rounds to fit and evaluate an estimator.
    rounds = 10
    X, y = datasets.load_digits(return_X_y=True)

    classifiers = [
        ("SGD", SGDClassifier(max_iter=110)),
        ("ASGD", SGDClassifier(max_iter=110, average=True)),
        ("Perceptron", Perceptron(max_iter=110)),
        (
            "Passive-Aggressive I",
            PassiveAggressiveClassifier(max_iter=110, loss="hinge", C=1.0, tol=1e-4),
        ),
        (
            "Passive-Aggressive II",
            PassiveAggressiveClassifier(
                max_iter=110, loss="squared_hinge", C=1.0, tol=1e-4
            ),
        ),
        (
            "SAG",
            LogisticRegression(max_iter=110, solver="sag", tol=1e-1, C=1.0e4 / X.shape[0]),
        ),
    ]

    # Training-set proportions, used as the x-axis of the plot.
    xx = 1.0 - np.array(heldout)

    for name, clf in classifiers:
        print("training %s" % name)
        # Re-seed so that every classifier is evaluated on the same sequence of splits.
        rng = np.random.RandomState(42)
        yy = []
        for i in heldout:
            yy_ = []
            for r in range(rounds):
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=i, random_state=rng
                )
                clf.fit(X_train, y_train)
                y_pred = clf.predict(X_test)
                # Test error rate = 1 - accuracy on the held-out samples.
                yy_.append(1 - np.mean(y_pred == y_test))
            # Average the test error over all rounds for this split size.
            yy.append(np.mean(yy_))
        plt.plot(xx, yy, label=name)

    plt.legend(loc="upper right")
    plt.xlabel("Proportion train")
    plt.ylabel("Test Error Rate")
    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 21.425 seconds)


.. _sphx_glr_download_auto_examples_linear_model_plot_sgd_comparison.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/linear_model/plot_sgd_comparison.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sgd_comparison.ipynb <plot_sgd_comparison.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sgd_comparison.py <plot_sgd_comparison.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_sgd_comparison.zip <plot_sgd_comparison.zip>`


.. include:: plot_sgd_comparison.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
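
Except for ``LogisticRegression`` with ``solver="sag"``, the estimators compared in this example also provide ``partial_fit``, which lets them learn incrementally from a stream of mini-batches. The example itself trains them in batch mode with ``fit``. The snippet below is a minimal sketch, not part of the original example, that feeds ``SGDClassifier`` and ``Perceptron`` the digits data one mini-batch at a time. The 50/50 split, the ``random_state`` values and the batch size of 100 are illustrative assumptions.

.. code-block:: Python

    import numpy as np

    from sklearn import datasets
    from sklearn.linear_model import Perceptron, SGDClassifier
    from sklearn.model_selection import train_test_split

    X, y = datasets.load_digits(return_X_y=True)
    # Illustrative assumption: a single 50/50 train/test split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=0
    )

    # partial_fit must be given the complete set of labels on its first call,
    # because a single mini-batch may not contain every digit.
    classes = np.unique(y)
    batch_size = 100  # illustrative assumption, not taken from the example

    errors = {}
    for name, clf in [
        ("SGD", SGDClassifier(random_state=0)),
        ("Perceptron", Perceptron(random_state=0)),
    ]:
        # One pass over the training data, one mini-batch at a time.
        for start in range(0, X_train.shape[0], batch_size):
            stop = start + batch_size
            clf.partial_fit(X_train[start:stop], y_train[start:stop], classes=classes)
        # Same error measure as in the example: 1 - accuracy on held-out samples.
        errors[name] = 1 - np.mean(clf.predict(X_test) == y_test)
        print("%s: test error after one pass = %.3f" % (name, errors[name]))

Calling ``partial_fit`` again on new batches continues updating the same model rather than refitting it from scratch, which is how these estimators can be used on data that arrives over time.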