.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/model_selection/plot_nested_cross_validation_iris.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. or to run this example in your browser via Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_model_selection_plot_nested_cross_validation_iris.py: ========================================= 嵌套与非嵌套交叉验证 ========================================= 本示例比较了在鸢尾花数据集分类器上使用非嵌套和嵌套交叉验证策略。嵌套交叉验证(CV)通常用于训练需要优化超参数的模型。嵌套CV估计了基础模型及其(超)参数搜索的泛化误差。选择最大化非嵌套CV的参数会使模型偏向数据集,从而产生过于乐观的评分。 不使用嵌套CV的模型选择使用相同的数据来调整模型参数和评估模型性能。因此,信息可能会“泄漏”到模型中并导致数据过拟合。这种效应的大小主要取决于数据集的大小和模型的稳定性。有关这些问题的分析,请参见Cawley和Talbot [1]_。 为避免此问题,嵌套CV有效地使用了一系列训练/验证/测试集拆分。在内循环中(此处由:class:`GridSearchCV ` 执行),通过将模型拟合到每个训练集来近似最大化评分,然后在验证集上直接最大化选择(超)参数。在外循环中(此处在:func:`cross_val_score ` 中),通过对多个数据集拆分的测试集评分进行平均来估计泛化误差。 下面的示例使用具有非线性核的支持向量分类器,通过网格搜索构建具有优化超参数的模型。我们通过比较非嵌套和嵌套CV策略的评分差异来比较它们的性能。 .. seealso:: - :ref:`cross_validation` - :ref:`grid_search` .. rubric:: 参考文献 .. [1] `Cawley, G.C.; Talbot, N.L.C. 关于模型选择中的过拟合及其在性能评估中的选择偏差。 J. Mach. Learn. Res 2010,11, 2079-2107. `_ .. GENERATED FROM PYTHON SOURCE LINES 26-109 .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_nested_cross_validation_iris_001.png :alt: Non-Nested and Nested Cross Validation on Iris Dataset :srcset: /auto_examples/model_selection/images/sphx_glr_plot_nested_cross_validation_iris_001.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-script-out .. code-block:: none Average difference of 0.007361 with std. dev. of 0.007760. | .. code-block:: Python import numpy as np from matplotlib import pyplot as plt from sklearn.datasets import load_iris from sklearn.model_selection import GridSearchCV, KFold, cross_val_score from sklearn.svm import SVC # Number of random trials # # NUM_TRIALS = 30 # 加载数据集 iris = load_iris() X_iris = iris.data y_iris = iris.target # 设置要优化的参数的可能值 p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]} # 我们将使用带有“rbf”核的支持向量分类器 svm = SVC(kernel="rbf") # 用于存储分数的数组 non_nested_scores = np.zeros(NUM_TRIALS) nested_scores = np.zeros(NUM_TRIALS) # Loop for each trial for i in range(NUM_TRIALS): # 选择内外循环的交叉验证技术,与数据集无关。 # 例如 "GroupKFold", "LeaveOneOut", "LeaveOneGroupOut" 等。 inner_cv = KFold(n_splits=4, shuffle=True, random_state=i) outer_cv = KFold(n_splits=4, shuffle=True, random_state=i) # 非嵌套参数搜索与评分 clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=outer_cv) clf.fit(X_iris, y_iris) non_nested_scores[i] = clf.best_score_ # 嵌套交叉验证与参数优化 clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv) nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv) nested_scores[i] = nested_score.mean() score_difference = non_nested_scores - nested_scores print( "Average difference of {:6f} with std. dev. of {:6f}.".format( score_difference.mean(), score_difference.std() ) ) # 绘制嵌套和非嵌套交叉验证中每次试验的得分 plt.figure() plt.subplot(211) (non_nested_scores_line,) = plt.plot(non_nested_scores, color="r") (nested_line,) = plt.plot(nested_scores, color="b") plt.ylabel("score", fontsize="14") plt.legend( [non_nested_scores_line, nested_line], ["Non-Nested CV", "Nested CV"], bbox_to_anchor=(0, 0.4, 0.5, 0), ) plt.title( "Non-Nested and Nested Cross Validation on Iris Dataset", x=0.5, y=1.1, fontsize="15", ) # 绘制差异的条形图。 plt.subplot(212) difference_plot = plt.bar(range(NUM_TRIALS), score_difference) plt.xlabel("Individual Trial #") plt.legend( [difference_plot], ["Non-Nested CV - Nested CV Score"], bbox_to_anchor=(0, 1, 0.8, 0), ) plt.ylabel("score difference", fontsize="14") plt.show() .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 2.556 seconds) .. _sphx_glr_download_auto_examples_model_selection_plot_nested_cross_validation_iris.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_nested_cross_validation_iris.ipynb :alt: Launch binder :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_nested_cross_validation_iris.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_nested_cross_validation_iris.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_nested_cross_validation_iris.zip ` .. include:: plot_nested_cross_validation_iris.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_