plot_learning_curves: 从训练集和测试集绘制学习曲线

一个用于绘制分类器学习曲线的函数。学习曲线对于分析模型是否遭受过拟合或欠拟合(高方差或高偏差)非常有用。该函数可以通过以下方式导入:

# 学习曲线绘制

本文将介绍如何使用 `mlxtend` 库绘制学习曲线。

此函数使用基于训练集和测试集(或验证集)的传统留出法。测试集保持不变,而训练集的大小逐渐增加。模型在训练集(大小不同)上进行拟合,并在相同的测试集上进行评估。

学习曲线可以用于诊断过拟合:

学习曲线还可以用来判断收集更多数据是否有用。更多内容见下面的示例1。

参考文献

-

示例 1

以下代码演示了我们如何为MNIST数据集的5000个样本子集构建学习曲线。使用4000个样本进行训练,1000个样本保留用于测试。

from mlxtend.plotting import plot_learning_curves
import matplotlib.pyplot as plt
from mlxtend.data import mnist_data
from mlxtend.preprocessing import shuffle_arrays_unison
from sklearn.neighbors import KNeighborsClassifier
import numpy as np


# 加载一些示例数据
X, y = mnist_data()
X, y = shuffle_arrays_unison(arrays=[X, y], random_seed=123)
X_train, X_test = X[:4000], X[4000:]
y_train, y_test = y[:4000], y[4000:]

clf = KNeighborsClassifier(n_neighbors=7)

plot_learning_curves(X_train, y_train, X_test, y_test, clf)
plt.show()

png

正如我们从上面的图中看到的,KNN模型可以从额外的训练数据中受益。也就是说,斜率表明如果我们有一个更大的训练集,测试集的错误可能会进一步减少。

此外,根据训练集和测试集性能之间的差距,该模型稍微过拟合。这可能通过增加KNN中的邻居数量(n_neighbors)来解决。

虽然这与分析分类器的性能无关,但大约20%的训练集大小显示该模型出现了欠拟合(训练误差和测试误差都较大),这可能是由于数据集大小过小造成的。

API

plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o', test_marker='^', scoring='misclassification error', suppress_plot=False, print_model=True, title_fontsize=12, style='default', legend_loc='best')

Plots learning curves of a classifier.

Parameters

Returns

Examples

For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_learning_curves/