嵌套与非嵌套交叉验证#

本示例比较了在鸢尾花数据集分类器上使用非嵌套和嵌套交叉验证策略。嵌套交叉验证(CV)通常用于训练需要优化超参数的模型。嵌套CV估计了基础模型及其(超)参数搜索的泛化误差。选择最大化非嵌套CV的参数会使模型偏向数据集,从而产生过于乐观的评分。

不使用嵌套CV的模型选择使用相同的数据来调整模型参数和评估模型性能。因此,信息可能会“泄漏”到模型中并导致数据过拟合。这种效应的大小主要取决于数据集的大小和模型的稳定性。有关这些问题的分析,请参见Cawley和Talbot [1]

为避免此问题,嵌套CV有效地使用了一系列训练/验证/测试集拆分。在内循环中(此处由:class:GridSearchCV <sklearn.model_selection.GridSearchCV> 执行),通过将模型拟合到每个训练集来近似最大化评分,然后在验证集上直接最大化选择(超)参数。在外循环中(此处在:func:cross_val_score <sklearn.model_selection.cross_val_score> 中),通过对多个数据集拆分的测试集评分进行平均来估计泛化误差。

下面的示例使用具有非线性核的支持向量分类器,通过网格搜索构建具有优化超参数的模型。我们通过比较非嵌套和嵌套CV策略的评分差异来比较它们的性能。

参考文献

Non-Nested and Nested Cross Validation on Iris Dataset
Average difference of 0.007361 with std. dev. of 0.007760.

import numpy as np
from matplotlib import pyplot as plt

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score
from sklearn.svm import SVC

# Number of random trials
#
#
NUM_TRIALS = 30

# 加载数据集
iris = load_iris()
X_iris = iris.data
y_iris = iris.target

# 设置要优化的参数的可能值
p_grid = {"C": [1, 10, 100], "gamma": [0.01, 0.1]}

# 我们将使用带有“rbf”核的支持向量分类器
svm = SVC(kernel="rbf")

# 用于存储分数的数组
non_nested_scores = np.zeros(NUM_TRIALS)
nested_scores = np.zeros(NUM_TRIALS)

# Loop for each trial
for i in range(NUM_TRIALS):
    # 选择内外循环的交叉验证技术,与数据集无关。
    # 例如 "GroupKFold", "LeaveOneOut", "LeaveOneGroupOut" 等。
    inner_cv = KFold(n_splits=4, shuffle=True, random_state=i)
    outer_cv = KFold(n_splits=4, shuffle=True, random_state=i)

    # 非嵌套参数搜索与评分
    clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=outer_cv)
    clf.fit(X_iris, y_iris)
    non_nested_scores[i] = clf.best_score_

    # 嵌套交叉验证与参数优化
    clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv)
    nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv)
    nested_scores[i] = nested_score.mean()

score_difference = non_nested_scores - nested_scores

print(
    "Average difference of {:6f} with std. dev. of {:6f}.".format(
        score_difference.mean(), score_difference.std()
    )
)

# 绘制嵌套和非嵌套交叉验证中每次试验的得分
plt.figure()
plt.subplot(211)
(non_nested_scores_line,) = plt.plot(non_nested_scores, color="r")
(nested_line,) = plt.plot(nested_scores, color="b")
plt.ylabel("score", fontsize="14")
plt.legend(
    [non_nested_scores_line, nested_line],
    ["Non-Nested CV", "Nested CV"],
    bbox_to_anchor=(0, 0.4, 0.5, 0),
)
plt.title(
    "Non-Nested and Nested Cross Validation on Iris Dataset",
    x=0.5,
    y=1.1,
    fontsize="15",
)

# 绘制差异的条形图。
plt.subplot(212)
difference_plot = plt.bar(range(NUM_TRIALS), score_difference)
plt.xlabel("Individual Trial #")
plt.legend(
    [difference_plot],
    ["Non-Nested CV - Nested CV Score"],
    bbox_to_anchor=(0, 1, 0.8, 0),
)
plt.ylabel("score difference", fontsize="14")

plt.show()

Total running time of the script: (0 minutes 2.556 seconds)

Related examples

通过排列检验分类评分的显著性

通过排列检验分类评分的显著性

连接多种特征提取方法

连接多种特征提取方法

多类训练元估计器概述

多类训练元估计器概述

在 scikit-learn 中可视化交叉验证行为

在 scikit-learn 中可视化交叉验证行为

Gallery generated by Sphinx-Gallery