.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection/plot_cv_indices.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_model_selection_plot_cv_indices.py>`
        to download the full example code, or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection_plot_cv_indices.py:


Visualizing cross-validation behavior in scikit-learn
=====================================================

Choosing the right cross-validation object is a crucial part of fitting a
model properly. There are many ways to split data into training and test
sets, for example to avoid overfitting the model or to standardize the
number of groups in the test sets.

This example visualizes the behavior of several common scikit-learn
objects for comparison.

.. GENERATED FROM PYTHON SOURCE LINES 10-31

.. code-block:: Python


    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.patches import Patch

    from sklearn.model_selection import (
        GroupKFold,
        GroupShuffleSplit,
        KFold,
        ShuffleSplit,
        StratifiedGroupKFold,
        StratifiedKFold,
        StratifiedShuffleSplit,
        TimeSeriesSplit,
    )

    rng = np.random.RandomState(1338)
    cmap_data = plt.cm.Paired
    cmap_cv = plt.cm.coolwarm
    n_splits = 4


.. GENERATED FROM PYTHON SOURCE LINES 32-40

Visualize our data
------------------

First, we must understand the structure of our data. It has 100 randomly
generated input datapoints, 3 classes split unevenly across datapoints
(10%, 30% and 60% of the samples), and 10 "groups" split unevenly across
datapoints.

As we'll see, some cross-validation objects do specific things with
labeled data, others behave differently with grouped data, and others do
not use this information.

To begin, we'll visualize our data.

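
The uneven class and group sizes described above are easy to check numerically. The sketch below is a simplified stand-in for the data-generation code in this example, not a copy of it. It builds the 10% / 30% / 60% class labels with ``np.repeat``. It then sets the sizes of 10 groups with a multinomial draw whose probabilities come from a Dirichlet prior. Because it uses an arbitrary seed (0), its group sizes will not match the ones plotted in this example.

.. code-block:: Python

    import numpy as np

    rng = np.random.RandomState(0)  # arbitrary seed, not the example's RNG state

    # 3 classes covering 10%, 30% and 60% of 100 samples
    y_demo = np.repeat([0, 1, 2], [10, 30, 60])

    # 10 groups: Dirichlet-distributed proportions, then a multinomial draw
    # decides how many of the 100 samples land in each group
    group_prior = rng.dirichlet([2] * 10)
    group_sizes = rng.multinomial(100, group_prior)
    groups_demo = np.repeat(np.arange(10), group_sizes)

    print("class sizes:", np.bincount(y_demo))  # class sizes: [10 30 60]
    print("group sizes:", group_sizes)  # uneven, but always sums to 100
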

.. GENERATED FROM PYTHON SOURCE LINES 40-82

.. code-block:: Python


    # Generate the class/group data
    n_points = 100
    X = rng.randn(n_points, 10)

    percentiles_classes = [0.1, 0.3, 0.6]
    y = np.hstack([[ii] * int(n_points * perc) for ii, perc in enumerate(percentiles_classes)])

    # Generate uneven groups
    group_prior = rng.dirichlet([2] * 10)
    groups = np.repeat(np.arange(10), rng.multinomial(n_points, group_prior))


    def visualize_groups(classes, groups, name):
        # Visualize dataset groups
        fig, ax = plt.subplots()
        ax.scatter(
            range(len(groups)),
            [0.5] * len(groups),
            c=groups,
            marker="_",
            lw=50,
            cmap=cmap_data,
        )
        ax.scatter(
            range(len(groups)),
            [3.5] * len(groups),
            c=classes,
            marker="_",
            lw=50,
            cmap=cmap_data,
        )
        ax.set(
            ylim=[-1, 5],
            yticks=[0.5, 3.5],
            yticklabels=["Data\ngroup", "Data\nclass"],
            xlabel="Sample index",
        )


    visualize_groups(y, groups, "no groups")



.. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_001.png
   :alt: plot cv indices
   :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 83-87

Define a function to visualize cross-validation behavior
--------------------------------------------------------

We'll define a function that lets us visualize the behavior of each
cross-validation object. We'll perform 4 splits of the data. On each
split, we'll visualize the indices chosen for the training set (in blue)
and the test set (in red).

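
Without the plotting code, each row of the figure is an indicator array with one entry per sample. The entry is 0 for training indices, 1 for test indices, and NaN for samples the splitter does not use in that iteration. The sketch below rebuilds those rows for ``KFold`` on 100 dummy samples. It then maps 0 and 1 to colors with ``matplotlib.colors.Normalize``, which is the linear rescaling ``scatter`` applies when ``vmin`` and ``vmax`` are given. It assumes the ``coolwarm`` colormap and the ``vmin=-0.2`` / ``vmax=1.2`` limits used in the function.

.. code-block:: Python

    import numpy as np
    from matplotlib import colormaps
    from matplotlib.colors import Normalize
    from sklearn.model_selection import KFold

    X_demo = np.zeros((100, 1))  # only the number of samples matters here

    rows = []
    for train_idx, test_idx in KFold(n_splits=4).split(X_demo):
        row = np.full(len(X_demo), np.nan)  # NaN = not used in this split
        row[test_idx] = 1  # test samples
        row[train_idx] = 0  # training samples
        rows.append(row)
    rows = np.array(rows)

    # Map 0 and 1 to colors the way the plot does (vmin=-0.2, vmax=1.2)
    norm = Normalize(vmin=-0.2, vmax=1.2)
    cmap = colormaps["coolwarm"]
    train_rgba = cmap(norm(0.0))  # 0 lands at ~0.14, on the blue side
    test_rgba = cmap(norm(1.0))  # 1 lands at ~0.86, on the red side

    print((rows == 1).sum(axis=1))  # [25 25 25 25] test samples per split

With these limits, 0 sits at about 0.14 and 1 at about 0.86 on the colormap, so neither color falls at the extreme ends of ``coolwarm``.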

.. GENERATED FROM PYTHON SOURCE LINES 87-135

.. code-block:: Python



    def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
        """Create a sample plot for indices of a cross-validation object."""
        use_groups = "Group" in type(cv).__name__
        groups = group if use_groups else None
        # Generate the training/testing visualizations for each CV split
        for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=groups)):
            # Fill in indices with the training/test groups
            indices = np.array([np.nan] * len(X))
            indices[tt] = 1
            indices[tr] = 0

            # Visualize the results
            ax.scatter(
                range(len(indices)),
                [ii + 0.5] * len(indices),
                c=indices,
                marker="_",
                lw=lw,
                cmap=cmap_cv,
                vmin=-0.2,
                vmax=1.2,
            )

        # Plot the data classes and groups at the end
        ax.scatter(
            range(len(X)), [ii + 1.5] * len(X), c=y, marker="_", lw=lw, cmap=cmap_data
        )

        ax.scatter(
            range(len(X)), [ii + 2.5] * len(X), c=group, marker="_", lw=lw, cmap=cmap_data
        )

        # Formatting
        yticklabels = list(range(n_splits)) + ["class", "group"]
        ax.set(
            yticks=np.arange(n_splits + 2) + 0.5,
            yticklabels=yticklabels,
            xlabel="Sample index",
            ylabel="CV iteration",
            ylim=[n_splits + 2.2, -0.2],
            xlim=[0, len(X)],
        )
        ax.set_title(type(cv).__name__, fontsize=15)
        return ax








.. GENERATED FROM PYTHON SOURCE LINES 136-137

Let's see how it looks for the :class:`~sklearn.model_selection.KFold`
cross-validation object:

.. GENERATED FROM PYTHON SOURCE LINES 137-143

.. code-block:: Python


    fig, ax = plt.subplots()
    cv = KFold(n_splits)
    plot_cv_indices(cv, X, y, groups, ax, n_splits)



.. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_002.png
   :alt: KFold
   :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 144-149

As you can see, by default the KFold cross-validation iterator does not
take either datapoint class or group into consideration. We can change
this by using either:

- ``StratifiedKFold`` to preserve the percentage of samples for each class.
- ``GroupKFold`` to ensure that the same group will not appear in two
  different folds.
- ``StratifiedGroupKFold`` to keep the constraint of ``GroupKFold`` while
  attempting to return stratified folds.

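
The first two guarantees in this list can also be checked numerically instead of by eye. The following sketch uses its own toy data. It keeps the same 10% / 30% / 60% class split as this example but, for simplicity, assumes 10 equally sized groups. It records the per-class counts in every ``StratifiedKFold`` test fold. It also counts how many groups ``GroupKFold`` places in both the training and test indices of a split.

.. code-block:: Python

    import numpy as np
    from sklearn.model_selection import GroupKFold, StratifiedKFold

    n_folds = 4
    X_demo = np.zeros((100, 1))  # both splitters ignore the feature values
    y_demo = np.repeat([0, 1, 2], [10, 30, 60])  # 10% / 30% / 60% classes
    groups_demo = np.repeat(np.arange(10), 10)  # assumption: 10 groups of 10 samples

    # StratifiedKFold: per-class counts in each test fold, compared with
    # the ideal share (class size / number of folds)
    expected = np.bincount(y_demo) / n_folds  # [2.5, 7.5, 15.0]
    strat_counts = []
    for _, test_idx in StratifiedKFold(n_splits=n_folds).split(X_demo, y_demo):
        strat_counts.append(np.bincount(y_demo[test_idx], minlength=3))
    worst = max(np.abs(c - expected).max() for c in strat_counts)
    print("largest deviation from the ideal class share:", worst)

    # GroupKFold: count the groups that appear on both sides of a split
    shared_sizes = []
    for train_idx, test_idx in GroupKFold(n_splits=n_folds).split(
        X_demo, y_demo, groups_demo
    ):
        shared = np.intersect1d(groups_demo[train_idx], groups_demo[test_idx])
        shared_sizes.append(shared.size)
    print("groups shared between train and test:", shared_sizes)  # [0, 0, 0, 0]

Each stratified test fold stays within one sample of the ideal per-class share, and no ``GroupKFold`` split shares a group between its training and test indices.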

.. GENERATED FROM PYTHON SOURCE LINES 149-163

.. code-block:: Python


    cvs = [StratifiedKFold, GroupKFold, StratifiedGroupKFold]

    for cv in cvs:
        fig, ax = plt.subplots(figsize=(6, 3))
        plot_cv_indices(cv(n_splits), X, y, groups, ax, n_splits)
        ax.legend(
            [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
            ["Testing set", "Training set"],
            loc=(1.02, 0.8),
        )
        # Make the legend fit
        plt.tight_layout()
        fig.subplots_adjust(right=0.7)



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_003.png
         :alt: StratifiedKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png
         :alt: GroupKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_005.png
         :alt: StratifiedGroupKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_005.png
         :class: sphx-glr-multi-img




.. GENERATED FROM PYTHON SOURCE LINES 164-171

Next we'll visualize this behavior for a number of CV iterators.

Visualize cross-validation indices for many CV objects
------------------------------------------------------

Let's visually compare the cross-validation behavior of many scikit-learn
cross-validation objects. Below we loop through several common
cross-validation objects, visualizing the behavior of each.

Note how some use the group/class information while others do not.

.. GENERATED FROM PYTHON SOURCE LINES 171-198

.. code-block:: Python


    cvs = [
        KFold,
        GroupKFold,
        ShuffleSplit,
        StratifiedKFold,
        StratifiedGroupKFold,
        GroupShuffleSplit,
        StratifiedShuffleSplit,
        TimeSeriesSplit,
    ]


    for cv in cvs:
        this_cv = cv(n_splits=n_splits)
        fig, ax = plt.subplots(figsize=(6, 3))
        plot_cv_indices(this_cv, X, y, groups, ax, n_splits)

        ax.legend(
            [Patch(color=cmap_cv(0.8)), Patch(color=cmap_cv(0.02))],
            ["Testing set", "Training set"],
            loc=(1.02, 0.8),
        )
        # Make the legend fit
        plt.tight_layout()
        fig.subplots_adjust(right=0.7)
    plt.show()



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png
         :alt: KFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png
         :alt: GroupKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png
         :alt: ShuffleSplit
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_008.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png
         :alt: StratifiedKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_009.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_010.png
         :alt: StratifiedGroupKFold
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_010.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_011.png
         :alt: GroupShuffleSplit
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_011.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_012.png
         :alt: StratifiedShuffleSplit
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_012.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_013.png
         :alt: TimeSeriesSplit
         :srcset: /auto_examples/model_selection/images/sphx_glr_plot_cv_indices_013.png
         :class: sphx-glr-multi-img





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.519 seconds)


.. _sphx_glr_download_auto_examples_model_selection_plot_cv_indices.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_cv_indices.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_cv_indices.ipynb <plot_cv_indices.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_cv_indices.py <plot_cv_indices.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_cv_indices.zip <plot_cv_indices.zip>`


.. include:: plot_cv_indices.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_