plot_confusion_matrix: 可视化混淆矩阵

用于通过matplotlib可视化混淆矩阵的工具函数

> `from mlxtend.plotting import plot_confusion_matrix`

概述

混淆矩阵

有关混淆矩阵的更多信息,请参见mlxtend.evaluate.confusion_matrix

参考文献

示例 1 - 二进制

from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

binary1 = np.array([[4, 1],
                    [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary1)
plt.show()

png

binary2 = np.array([[21, 1],
                    [3, 1]])

fig, ax = plot_confusion_matrix(conf_mat=binary2, figsize=(2, 2))
plt.show()

png

示例 2 - 带颜色条的二进制绝对值和相对值

binary = np.array([[4, 1],
                   [1, 2]])

fig, ax = plot_confusion_matrix(conf_mat=binary,
                                show_absolute=True,
                                show_normed=True,
                                colorbar=True)
plt.show()

png

示例 3 - 多类别相对

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True)
plt.show()

png

示例 4 - 添加类名

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

class_names = ['class a', 'class b', 'class c', 'class d']

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                show_absolute=False,
                                show_normed=True,
                                class_names=class_names)
plt.show()

png

示例 5 - 更改颜色映射和字体颜色

Matplotlib颜色图可以通过cmap参数作为替代颜色图选择。可以在这里找到颜色图的列表:https://matplotlib.org/stable/tutorials/colors/colormaps.html

multiclass = np.array([[2, 1, 0, 0],
                       [1, 2, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 1]])

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                cmap='summer')

plt.show()

png

如上所示,字体颜色阈值可能对某些颜色映射不起作用。默认情况下,所有大于最大单元值的0.5倍的值被转换为白色,而所有等于或小于最大单元值的0.5倍的值被转换为黑色。

如果您想将所有以上的值更改为,例如,白色,您可以将颜色阈值设置为负数。或者,如果您想将所有字体颜色设置为黑色,请选择一个等于或大于1的阈值。

fig, ax = plot_confusion_matrix(conf_mat=multiclass,
                                colorbar=True,
                                fontcolor_threshold=1,
                                cmap='summer')

plt.show()

png

示例 6 - 归一化色图以突出非对角线部分

假设我们有以下高准确率分类器的混淆矩阵:

class_dict = {0: 'airplane',
              1: 'automobile',
              2: 'bird',
              3: 'cat',
              4: 'deer',
              5: 'dog',
              6: 'frog'}

cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
                 [0, 1123, 3, 1, 0, 1, 2],
                 [2, 0, 1025, 0, 0, 0, 1],
                 [0, 0, 0, 1005, 0, 2, 0],
                 [0, 1, 1, 0, 967, 0, 4],
                 [0, 0, 0, 6, 0, 881, 3],
                 [2, 3, 0, 1, 3, 4, 941]])

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
)

png

很难注意到模型出错的单元格。使用对数归一化的彩色图,这些对角线外的错误在一眼看来就更容易被发现:

import matplotlib

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
    norm_colormap=matplotlib.colors.LogNorm()  
)

png

API

plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=False, show_absolute=True, show_normed=False, class_names=None, figure=None, axis=None, fontcolor_threshold=0.5)

Plot a confusion matrix via matplotlib.

Parameters

Returns

Examples

For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/