plot_confusion_matrix: 可视化混淆矩阵
用于通过matplotlib可视化混淆矩阵的工具函数
> `from mlxtend.plotting import plot_confusion_matrix`
概述
混淆矩阵
有关混淆矩阵的更多信息,请参见mlxtend.evaluate.confusion_matrix
。
参考文献
- -
示例 1 - 二进制
from mlxtend.plotting import plot_confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
binary1 = np.array([[4, 1],
[1, 2]])
fig, ax = plot_confusion_matrix(conf_mat=binary1)
plt.show()
binary2 = np.array([[21, 1],
[3, 1]])
fig, ax = plot_confusion_matrix(conf_mat=binary2, figsize=(2, 2))
plt.show()
示例 2 - 带颜色条的二进制绝对值和相对值
binary = np.array([[4, 1],
[1, 2]])
fig, ax = plot_confusion_matrix(conf_mat=binary,
show_absolute=True,
show_normed=True,
colorbar=True)
plt.show()
示例 3 - 多类别相对
multiclass = np.array([[2, 1, 0, 0],
[1, 2, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
colorbar=True,
show_absolute=False,
show_normed=True)
plt.show()
示例 4 - 添加类名
multiclass = np.array([[2, 1, 0, 0],
[1, 2, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
class_names = ['class a', 'class b', 'class c', 'class d']
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
colorbar=True,
show_absolute=False,
show_normed=True,
class_names=class_names)
plt.show()
示例 5 - 更改颜色映射和字体颜色
Matplotlib颜色图可以通过cmap
参数作为替代颜色图选择。可以在这里找到颜色图的列表:https://matplotlib.org/stable/tutorials/colors/colormaps.html
multiclass = np.array([[2, 1, 0, 0],
[1, 2, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
colorbar=True,
cmap='summer')
plt.show()
如上所示,字体颜色阈值可能对某些颜色映射不起作用。默认情况下,所有大于最大单元值的0.5倍的值被转换为白色,而所有等于或小于最大单元值的0.5倍的值被转换为黑色。
如果您想将所有以上的值更改为,例如,白色,您可以将颜色阈值设置为负数。或者,如果您想将所有字体颜色设置为黑色,请选择一个等于或大于1的阈值。
fig, ax = plot_confusion_matrix(conf_mat=multiclass,
colorbar=True,
fontcolor_threshold=1,
cmap='summer')
plt.show()
示例 6 - 归一化色图以突出非对角线部分
假设我们有以下高准确率分类器的混淆矩阵:
class_dict = {0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog'}
cmat = np.array([[972, 0, 1, 1, 1, 1, 3],
[0, 1123, 3, 1, 0, 1, 2],
[2, 0, 1025, 0, 0, 0, 1],
[0, 0, 0, 1005, 0, 2, 0],
[0, 1, 1, 0, 967, 0, 4],
[0, 0, 0, 6, 0, 881, 3],
[2, 3, 0, 1, 3, 4, 941]])
fig, ax = plot_confusion_matrix(
conf_mat=cmat,
class_names=class_dict.values(),
)
很难注意到模型出错的单元格。使用对数归一化的彩色图,这些对角线外的错误在一眼看来就更容易被发现:
import matplotlib
fig, ax = plot_confusion_matrix(
conf_mat=cmat,
class_names=class_dict.values(),
norm_colormap=matplotlib.colors.LogNorm()
)
API
plot_confusion_matrix(conf_mat, hide_spines=False, hide_ticks=False, figsize=None, cmap=None, colorbar=False, show_absolute=True, show_normed=False, class_names=None, figure=None, axis=None, fontcolor_threshold=0.5)
Plot a confusion matrix via matplotlib.
Parameters
-
conf_mat
: array-like, shape = [n_classes, n_classes]Confusion matrix from evaluate.confusion matrix.
-
hide_spines
: bool (default: False)Hides axis spines if True.
-
hide_ticks
: bool (default: False)Hides axis ticks if True
-
figsize
: tuple (default: (2.5, 2.5))Height and width of the figure
-
cmap
: matplotlib colormap (default:None
)Uses matplotlib.pyplot.cm.Blues if
None
-
colorbar
: bool (default: False)Shows a colorbar if True
-
show_absolute
: bool (default: True)Shows absolute confusion matrix coefficients if True. At least one of
show_absolute
orshow_normed
must be True. -
show_normed
: bool (default: False)Shows normed confusion matrix coefficients if True. The normed confusion matrix coefficients give the proportion of training examples per class that are assigned the correct label. At least one of
show_absolute
orshow_normed
must be True. -
class_names
: array-like, shape = [n_classes] (default: None)List of class names. If not
None
, ticks will be set to these values. -
figure
: None or Matplotlib figure (default: None)If None will create a new figure.
-
axis
: None or Matplotlib figure axis (default: None)If None will create a new axis.
-
fontcolor_threshold
: Float (default: 0.5)Sets a threshold for choosing black and white font colors for the cells. By default all values larger than 0.5 times the maximum cell value are converted to white, and everything equal or smaller than 0.5 times the maximum cell value are converted to black.
Returns
-
fig, ax
: matplotlib.pyplot subplot objectsFigure and axis elements of the subplot.
Examples
For usage examples, please see https://rasbt.github.io/mlxtend/user_guide/plotting/plot_confusion_matrix/