plot_learning_curves: 从训练集和测试集绘制学习曲线
一个用于绘制分类器学习曲线的函数。学习曲线对于分析模型是否遭受过拟合或欠拟合(高方差或高偏差)非常有用。该函数可以通过以下方式导入:
# 学习曲线绘制
本文将介绍如何使用 `mlxtend` 库绘制学习曲线。
此函数使用基于训练集和测试集(或验证集)的传统留出法。测试集保持不变,而训练集的大小逐渐增加。模型在训练集(大小不同)上进行拟合,并在相同的测试集上进行评估。
学习曲线可以用于诊断过拟合:
- 如果训练性能和测试性能之间存在较大差距,则模型很可能存在过拟合问题。
- 如果训练误差和测试误差都很大,则模型很可能对数据产生了欠拟合。
学习曲线还可以用来判断收集更多数据是否有用。更多内容见下面的示例1。
参考文献
-
示例 1
以下代码演示了我们如何为MNIST数据集的5000个样本子集构建学习曲线。使用4000个样本进行训练,1000个样本保留用于测试。
from mlxtend.plotting import plot_learning_curves
import matplotlib.pyplot as plt
from mlxtend.data import mnist_data
from mlxtend.preprocessing import shuffle_arrays_unison
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
# 加载一些示例数据
X, y = mnist_data()
X, y = shuffle_arrays_unison(arrays=[X, y], random_seed=123)
X_train, X_test = X[:4000], X[4000:]
y_train, y_test = y[:4000], y[4000:]
clf = KNeighborsClassifier(n_neighbors=7)
plot_learning_curves(X_train, y_train, X_test, y_test, clf)
plt.show()
正如我们从上面的图中看到的,KNN模型可以从额外的训练数据中受益。也就是说,斜率表明如果我们有一个更大的训练集,测试集的错误可能会进一步减少。
此外,根据训练集和测试集性能之间的差距,该模型稍微过拟合。这可能通过增加KNN中的邻居数量(n_neighbors
)来解决。
虽然这与分析分类器的性能无关,但大约20%的训练集大小显示该模型出现了欠拟合(训练误差和测试误差都较大),这可能是由于数据集大小过小造成的。
API
plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o', test_marker='^', scoring='misclassification error', suppress_plot=False, print_model=True, title_fontsize=12, style='default', legend_loc='best')
Plots learning curves of a classifier.
Parameters
-
X_train
: array-like, shape = [n_samples, n_features]Feature matrix of the training dataset.
-
y_train
: array-like, shape = [n_samples]True class labels of the training dataset.
-
X_test
: array-like, shape = [n_samples, n_features]Feature matrix of the test dataset.
-
y_test
: array-like, shape = [n_samples]True class labels of the test dataset.
-
clf
: Classifier object. Must have a .predict .fit method. -
train_marker
: str (default: 'o')Marker for the training set line plot.
-
test_marker
: str (default: '^')Marker for the test set line plot.
-
scoring
: str (default: 'misclassification error')If not 'misclassification error', accepts the following metrics (from scikit-learn): {'accuracy', 'average_precision', 'f1_micro', 'f1_macro', 'f1_weighted', 'f1_samples', 'log_loss', 'precision', 'recall', 'roc_auc', 'adjusted_rand_score', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'r2'}
-
suppress_plot=False
: bool (default: False)Suppress matplotlib plots if True. Recommended for testing purposes.
-
print_model
: bool (default: True)Print model parameters in plot title if True.
-
style
: str (default: 'default')Matplotlib style. For more styles, please see https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html
-
legend_loc
: str (default: 'best')Where to place the plot legend: {'best', 'upper left', 'upper right', 'lower left', 'lower right'}
Returns
errors
: (training_error, test_error): tuple of lists
Examples
For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_learning_curves/