精确率-召回率#

评估分类器输出质量的精确率-召回率指标示例。

当类别非常不平衡时,精确率-召回率是预测成功的有用度量。在信息检索中,精确率是实际返回的相关项中的相关项的比例,而召回率是应返回的所有项中返回的项的比例。这里的“相关性”指的是被正向标记的项,即真正例和假负例。

精确率(\(P\) )定义为真正例(\(T_p\) )的数量除以真正例加上假正例(\(F_p\) )的数量。

\[P = rac{T_p}{T_p+F_p}\]

召回率(\(R\) )定义为真正例(\(T_p\) )的数量除以真正例加上假负例(\(F_n\) )的数量。

\[R = rac{T_p}{T_p + F_n}\]

精确率-召回率曲线显示了不同阈值下精确率和召回率之间的权衡。曲线下面积高表示高召回率和高精确率。高精确率通过在返回结果中很少有假正例来实现,而高召回率通过在相关结果中很少有假负例来实现。高分数表明分类器返回的结果准确(高精确率),并且返回了大多数相关结果(高召回率)。

高召回率但低精确率的系统返回大多数相关项,但返回结果中错误标记的比例很高。高精确率但低召回率的系统则相反,返回的相关项很少,但其预测标签与实际标签相比大多是正确的。理想的高精确率和高召回率系统将返回大多数相关项,并且大多数结果标记正确。

精确率的定义(:math:` rac{T_p}{T_p + F_p}` )表明降低分类器的阈值可能会通过增加返回结果的数量来增加分母。如果之前的阈值设置过高,新结果可能都是真正例,这将提高精确率。如果之前的阈值设置合适或过低,进一步降低阈值将引入假正例,降低精确率。

召回率定义为 :math:` rac{T_p}{T_p+F_n}` ,其中 \(T_p+F_n\) 不依赖于分类器阈值。改变分类器阈值只能改变分子 \(T_p\) 。降低分类器阈值可能会通过增加真正例的数量来提高召回率。也有可能降低阈值会使召回率保持不变,而精确率波动。因此,精确率不一定随着召回率的增加而降低。

在图的阶梯区域可以观察到召回率和精确率之间的关系——在这些阶梯的边缘,阈值的微小变化会显著降低精确率,而召回率的增益很小。

平均精确率 (AP)将这样的图总结为在每个阈值下实现的精确率的加权平均值,前一个阈值的召回率增加用作权重:

:math:` ext{AP} = sum_n (R_n - R_{n-1}) P_n`

其中 \(P_n\)\(R_n\) 是第 n 个阈值下的精确率和召回率。对 \((R_k, P_k)\) 的一对称为*操作点*。

AP 和操作点下的梯形面积(sklearn.metrics.auc )是总结精确率-召回率曲线的常用方法,导致不同的结果。更多信息请参阅 用户指南

精确率-召回率曲线通常用于二分类研究分类器的输出。为了将精确率-召回率曲线和平均精确率扩展到多类或多标签分类,有必要将输出二值化。可以为每个标签绘制一条曲线,但也可以通过将标签指示矩阵的每个元素视为二元预测来绘制精确率-召回率曲线(微平均 )。

在二元分类设置中#

数据集和模型

我们将使用线性支持向量分类器(Linear SVC)来区分两种类型的鸢尾花。

import numpy as np

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# 添加噪声特征
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)

# 限制为前两个类别,并分为训练和测试
X_train, X_test, y_train, y_test = train_test_split(
    X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
)

线性支持向量分类器(Linear SVC)期望每个特征具有相似的取值范围。因此,我们将首先使用 StandardScaler 对数据进行缩放。

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
classifier.fit(X_train, y_train)
Pipeline(steps=[('standardscaler', StandardScaler()),
                ('linearsvc',
                 LinearSVC(random_state=RandomState(MT19937) at 0xFFFF7C597E40))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


绘制精确率-召回率曲线#

要绘制精确率-召回率曲线,您应该使用 PrecisionRecallDisplay 。实际上,根据您是否已经计算了分类器的预测,有两种可用的方法。

让我们首先在没有分类器预测的情况下绘制精确率-召回率曲线。我们使用 from_estimator ,它在为我们绘制曲线之前计算预测值。

from sklearn.metrics import PrecisionRecallDisplay

display = PrecisionRecallDisplay.from_estimator(
    classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True
)
_ = display.ax_.set_title("2-class Precision-Recall curve")
2-class Precision-Recall curve

如果我们已经获得了模型的估计概率或分数,那么我们可以使用 from_predictions

y_score = classifier.decision_function(X_test)

display = PrecisionRecallDisplay.from_predictions(
    y_test, y_score, name="LinearSVC", plot_chance_level=True
)
_ = display.ax_.set_title("2-class Precision-Recall curve")
2-class Precision-Recall curve

在多标签设置中#

精确率-召回率曲线不支持多标签设置。然而,可以决定如何处理这种情况。我们在下面展示了一个这样的例子。

创建多标签数据,拟合并预测

我们创建了一个多标签数据集,以说明多标签设置中的精确率-召回率。

from sklearn.preprocessing import label_binarize

# 使用 label_binarize 进行多标签设置
Y = label_binarize(y, classes=[0, 1, 2])
n_classes = Y.shape[1]

# 划分为训练集和测试集
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.5, random_state=random_state
)

我们使用 OneVsRestClassifier 进行多标签预测。

from sklearn.multiclass import OneVsRestClassifier

classifier = OneVsRestClassifier(
    make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
)
classifier.fit(X_train, Y_train)
y_score = classifier.decision_function(X_test)

多标签设置中的平均精度得分#

from sklearn.metrics import average_precision_score, precision_recall_curve

# 对于每个类
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
    average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

# A "micro-average": quantifying score on all classes jointly
#
#
precision["micro"], recall["micro"], _ = precision_recall_curve(
    Y_test.ravel(), y_score.ravel()
)
average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")

绘制微平均精确率-召回率曲线#

from collections import Counter

display = PrecisionRecallDisplay(
    recall=recall["micro"],
    precision=precision["micro"],
    average_precision=average_precision["micro"],
    prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size,
)
display.plot(plot_chance_level=True)
_ = display.ax_.set_title("Micro-averaged over all classes")
Micro-averaged over all classes

绘制每个类别的精确率-召回率曲线和等值F1曲线#

from itertools import cycle

import matplotlib.pyplot as plt

# 设置绘图细节
colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])

_, ax = plt.subplots(figsize=(7, 8))

f_scores = np.linspace(0.2, 0.8, num=4)
lines, labels = [], []
for f_score in f_scores:
    x = np.linspace(0.01, 1)
    y = f_score * x / (2 * x - f_score)
    (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
    plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

display = PrecisionRecallDisplay(
    recall=recall["micro"],
    precision=precision["micro"],
    average_precision=average_precision["micro"],
)
display.plot(ax=ax, name="Micro-average precision-recall", color="gold")

for i, color in zip(range(n_classes), colors):
    display = PrecisionRecallDisplay(
        recall=recall[i],
        precision=precision[i],
        average_precision=average_precision[i],
    )
    display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)

# 添加等F1曲线的图例
handles, labels = display.ax_.get_legend_handles_labels()
handles.extend([l])
labels.extend(["iso-f1 curves"])
# 设置图例和坐标轴
ax.legend(handles=handles, labels=labels, loc="best")
ax.set_title("Extension of Precision-Recall curve to multi-class")

plt.show()
Extension of Precision-Recall curve to multi-class

Total running time of the script: (0 minutes 0.174 seconds)

Related examples

使用显示对象进行可视化

使用显示对象进行可视化

网格搜索与交叉验证的自定义重拟合策略

网格搜索与交叉验证的自定义重拟合策略

多分类接收者操作特性(ROC)

多分类接收者操作特性(ROC)

L2 正则化对岭回归系数的影响

L2 正则化对岭回归系数的影响

Gallery generated by Sphinx-Gallery