平衡模型复杂性和交叉验证得分#

此示例通过在最佳准确性得分的1个标准差内找到一个不错的准确性,同时最小化PCA组件的数量来平衡模型复杂性和交叉验证得分[1]。

图中显示了交叉验证得分和PCA组件数量之间的权衡。平衡的情况是当n_components=10且accuracy=0.88时,这落在最佳准确性得分的1个标准差范围内。

[1] Hastie, T., Tibshirani, R., Friedman, J. (2001). 模型评估与选择. 统计学习的要素 (第219-260页). 纽约, 美国: 纽约施普林格公司.

Balance model complexity and cross-validated score
The best_index_ is 2
The n_components selected is 10
The corresponding accuracy score is 0.88

# Author: Wenhao Zhang <wenhaoz@ucla.edu>

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC


def lower_bound(cv_results):
    """计算最佳 `mean_test_scores` 在1个标准差内的下界。

Parameters
----------
cv_results : numpy(masked) ndarrays 的字典
    参见 `GridSearchCV` 的属性 cv_results_

返回
-------
float
    最佳 `mean_test_score` 在1个标准差内的下界。
"""
    best_score_idx = np.argmax(cv_results["mean_test_score"])

    return (
        cv_results["mean_test_score"][best_score_idx]
        - cv_results["std_test_score"][best_score_idx]
    )


def best_low_complexity(cv_results):
    """平衡模型复杂性与交叉验证得分。

Parameters
----------
cv_results : dict of numpy(masked) ndarrays
    请参阅 `GridSearchCV` 的属性 cv_results_。

返回
------
int
    返回一个模型的索引,该模型在测试得分在最佳 `mean_test_score` 的1个标准差内的同时,具有最少的PCA组件。
"""
    threshold = lower_bound(cv_results)
    candidate_idx = np.flatnonzero(cv_results["mean_test_score"] >= threshold)
    best_idx = candidate_idx[
        cv_results["param_reduce_dim__n_components"][candidate_idx].argmin()
    ]
    return best_idx


pipe = Pipeline(
    [
        ("reduce_dim", PCA(random_state=42)),
        ("classify", LinearSVC(random_state=42, C=0.01)),
    ]
)

param_grid = {"reduce_dim__n_components": [6, 8, 10, 12, 14]}

grid = GridSearchCV(
    pipe,
    cv=10,
    n_jobs=1,
    param_grid=param_grid,
    scoring="accuracy",
    refit=best_low_complexity,
)
X, y = load_digits(return_X_y=True)
grid.fit(X, y)

n_components = grid.cv_results_["param_reduce_dim__n_components"]
test_scores = grid.cv_results_["mean_test_score"]

plt.figure()
plt.bar(n_components, test_scores, width=1.3, color="b")

lower = lower_bound(grid.cv_results_)
plt.axhline(np.max(test_scores), linestyle="--", color="y", label="Best score")
plt.axhline(lower, linestyle="--", color=".5", label="Best score - 1 std")

plt.title("Balance model complexity and cross-validated score")
plt.xlabel("Number of PCA components used")
plt.ylabel("Digit classification accuracy")
plt.xticks(n_components.tolist())
plt.ylim((0, 1.0))
plt.legend(loc="upper left")

best_index_ = grid.best_index_

print("The best_index_ is %d" % best_index_)
print("The n_components selected is %d" % n_components[best_index_])
print(
    "The corresponding accuracy score is %.2f"
    % grid.cv_results_["mean_test_score"][best_index_]
)
plt.show()

Total running time of the script: (0 minutes 0.724 seconds)

Related examples

带交叉验证的递归特征消除

带交叉验证的递归特征消除

网格搜索与交叉验证的自定义重拟合策略

网格搜索与交叉验证的自定义重拟合策略

比较随机森林和直方图梯度提升模型

比较随机森林和直方图梯度提升模型

缓存最近邻

缓存最近邻

Gallery generated by Sphinx-Gallery