混淆矩阵#

使用混淆矩阵评估分类器在鸢尾花数据集上的输出质量示例。对角线元素表示预测标签等于真实标签的点的数量,而非对角线元素则表示分类器错误标记的点。混淆矩阵的对角线值越高越好,表明正确预测的数量多。

图中显示了按类别支持大小(每个类别中的元素数量)进行归一化和未归一化的混淆矩阵。在类别不平衡的情况下,这种归一化可以更直观地解释哪个类别被错误分类。

这里的结果不如预期,因为我们选择的正则化参数C不是最佳的。在实际应用中,这个参数通常使用 调整估计器的超参数 选择。

  • Confusion matrix, without normalization
  • Normalized confusion matrix
Confusion matrix, without normalization
[[13  0  0]
 [ 0 10  6]
 [ 0  0  9]]
Normalized confusion matrix
[[1.   0.   0.  ]
 [0.   0.62 0.38]
 [0.   0.   1.  ]]

import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets, svm
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split

# 导入一些数据来玩玩
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# 将数据分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# 运行分类器,使用正则化过度的模型(C 值过低)以观察其对结果的影响
classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

np.set_printoptions(precision=2)

# 绘制未归一化的混淆矩阵
titles_options = [
    ("Confusion matrix, without normalization", None),
    ("Normalized confusion matrix", "true"),
]
for title, normalize in titles_options:
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize=normalize,
    )
    disp.ax_.set_title(title)

    print(title)
    print(disp.confusion_matrix)

plt.show()

Total running time of the script: (0 minutes 0.092 seconds)

Related examples

最近邻分类

最近邻分类

绘制在鸢尾花数据集上训练的决策树的决策边界

绘制在鸢尾花数据集上训练的决策树的决策边界

在鸢尾花数据集上绘制不同的SVM分类器

在鸢尾花数据集上绘制不同的SVM分类器

绘制验证曲线

绘制验证曲线

Gallery generated by Sphinx-Gallery