.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/model_selection/plot_precision_recall.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. or to run this example in your browser via Binder .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_model_selection_plot_precision_recall.py: ================ 精确率-召回率 ================ 评估分类器输出质量的精确率-召回率指标示例。 当类别非常不平衡时,精确率-召回率是预测成功的有用度量。在信息检索中,精确率是实际返回的相关项中的相关项的比例,而召回率是应返回的所有项中返回的项的比例。这里的“相关性”指的是被正向标记的项,即真正例和假负例。 精确率(:math:`P` )定义为真正例(:math:`T_p` )的数量除以真正例加上假正例(:math:`F_p` )的数量。 .. math:: P = rac{T_p}{T_p+F_p} 召回率(:math:`R` )定义为真正例(:math:`T_p` )的数量除以真正例加上假负例(:math:`F_n` )的数量。 .. math:: R = rac{T_p}{T_p + F_n} 精确率-召回率曲线显示了不同阈值下精确率和召回率之间的权衡。曲线下面积高表示高召回率和高精确率。高精确率通过在返回结果中很少有假正例来实现,而高召回率通过在相关结果中很少有假负例来实现。高分数表明分类器返回的结果准确(高精确率),并且返回了大多数相关结果(高召回率)。 高召回率但低精确率的系统返回大多数相关项,但返回结果中错误标记的比例很高。高精确率但低召回率的系统则相反,返回的相关项很少,但其预测标签与实际标签相比大多是正确的。理想的高精确率和高召回率系统将返回大多数相关项,并且大多数结果标记正确。 精确率的定义(:math:` rac{T_p}{T_p + F_p}` )表明降低分类器的阈值可能会通过增加返回结果的数量来增加分母。如果之前的阈值设置过高,新结果可能都是真正例,这将提高精确率。如果之前的阈值设置合适或过低,进一步降低阈值将引入假正例,降低精确率。 召回率定义为 :math:` rac{T_p}{T_p+F_n}` ,其中 :math:`T_p+F_n` 不依赖于分类器阈值。改变分类器阈值只能改变分子 :math:`T_p` 。降低分类器阈值可能会通过增加真正例的数量来提高召回率。也有可能降低阈值会使召回率保持不变,而精确率波动。因此,精确率不一定随着召回率的增加而降低。 在图的阶梯区域可以观察到召回率和精确率之间的关系——在这些阶梯的边缘,阈值的微小变化会显著降低精确率,而召回率的增益很小。 **平均精确率** (AP)将这样的图总结为在每个阈值下实现的精确率的加权平均值,前一个阈值的召回率增加用作权重: :math:` ext{AP} = \sum_n (R_n - R_{n-1}) P_n` 其中 :math:`P_n` 和 :math:`R_n` 是第 n 个阈值下的精确率和召回率。对 :math:`(R_k, P_k)` 的一对称为*操作点*。 AP 和操作点下的梯形面积(:func:`sklearn.metrics.auc` )是总结精确率-召回率曲线的常用方法,导致不同的结果。更多信息请参阅 :ref:`用户指南 ` 。 精确率-召回率曲线通常用于二分类研究分类器的输出。为了将精确率-召回率曲线和平均精确率扩展到多类或多标签分类,有必要将输出二值化。可以为每个标签绘制一条曲线,但也可以通过将标签指示矩阵的每个元素视为二元预测来绘制精确率-召回率曲线(:ref:`微平均 ` )。 .. note:: 另请参阅 :func:`sklearn.metrics.average_precision_score` , :func:`sklearn.metrics.recall_score` , :func:`sklearn.metrics.precision_score` , :func:`sklearn.metrics.f1_score` .. GENERATED FROM PYTHON SOURCE LINES 49-55 在二元分类设置中 --------------------------------- 数据集和模型 我们将使用线性支持向量分类器(Linear SVC)来区分两种类型的鸢尾花。 .. GENERATED FROM PYTHON SOURCE LINES 55-72 .. code-block:: Python import numpy as np from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split X, y = load_iris(return_X_y=True) # 添加噪声特征 random_state = np.random.RandomState(0) n_samples, n_features = X.shape X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1) # 限制为前两个类别,并分为训练和测试 X_train, X_test, y_train, y_test = train_test_split( X[y < 2], y[y < 2], test_size=0.5, random_state=random_state ) .. GENERATED FROM PYTHON SOURCE LINES 73-74 线性支持向量分类器(Linear SVC)期望每个特征具有相似的取值范围。因此,我们将首先使用 :class:`~sklearn.preprocessing.StandardScaler` 对数据进行缩放。 .. GENERATED FROM PYTHON SOURCE LINES 74-82 .. code-block:: Python from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.svm import LinearSVC classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state)) classifier.fit(X_train, y_train) .. raw:: html
Pipeline(steps=[('standardscaler', StandardScaler()),
                    ('linearsvc',
                     LinearSVC(random_state=RandomState(MT19937) at 0xFFFF7C597E40))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


.. GENERATED FROM PYTHON SOURCE LINES 83-89 绘制精确率-召回率曲线 ............................... 要绘制精确率-召回率曲线,您应该使用 :class:`~sklearn.metrics.PrecisionRecallDisplay` 。实际上,根据您是否已经计算了分类器的预测,有两种可用的方法。 让我们首先在没有分类器预测的情况下绘制精确率-召回率曲线。我们使用 :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` ,它在为我们绘制曲线之前计算预测值。 .. GENERATED FROM PYTHON SOURCE LINES 89-96 .. code-block:: Python from sklearn.metrics import PrecisionRecallDisplay display = PrecisionRecallDisplay.from_estimator( classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True ) _ = display.ax_.set_title("2-class Precision-Recall curve") .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png :alt: 2-class Precision-Recall curve :srcset: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 97-98 如果我们已经获得了模型的估计概率或分数,那么我们可以使用 :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` 。 .. GENERATED FROM PYTHON SOURCE LINES 98-106 .. code-block:: Python y_score = classifier.decision_function(X_test) display = PrecisionRecallDisplay.from_predictions( y_test, y_score, name="LinearSVC", plot_chance_level=True ) _ = display.ax_.set_title("2-class Precision-Recall curve") .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_002.png :alt: 2-class Precision-Recall curve :srcset: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 107-115 在多标签设置中 ----------------- 精确率-召回率曲线不支持多标签设置。然而,可以决定如何处理这种情况。我们在下面展示了一个这样的例子。 创建多标签数据,拟合并预测 我们创建了一个多标签数据集,以说明多标签设置中的精确率-召回率。 .. GENERATED FROM PYTHON SOURCE LINES 115-127 .. code-block:: Python from sklearn.preprocessing import label_binarize # 使用 label_binarize 进行多标签设置 Y = label_binarize(y, classes=[0, 1, 2]) n_classes = Y.shape[1] # 划分为训练集和测试集 X_train, X_test, Y_train, Y_test = train_test_split( X, Y, test_size=0.5, random_state=random_state ) .. GENERATED FROM PYTHON SOURCE LINES 128-129 我们使用 :class:`~sklearn.multiclass.OneVsRestClassifier` 进行多标签预测。 .. GENERATED FROM PYTHON SOURCE LINES 129-139 .. code-block:: Python from sklearn.multiclass import OneVsRestClassifier classifier = OneVsRestClassifier( make_pipeline(StandardScaler(), LinearSVC(random_state=random_state)) ) classifier.fit(X_train, Y_train) y_score = classifier.decision_function(X_test) .. GENERATED FROM PYTHON SOURCE LINES 140-142 多标签设置中的平均精度得分 ................................................... .. GENERATED FROM PYTHON SOURCE LINES 142-161 .. code-block:: Python from sklearn.metrics import average_precision_score, precision_recall_curve # 对于每个类 precision = dict() recall = dict() average_precision = dict() for i in range(n_classes): precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i]) average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i]) # A "micro-average": quantifying score on all classes jointly # # precision["micro"], recall["micro"], _ = precision_recall_curve( Y_test.ravel(), y_score.ravel() ) average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro") .. GENERATED FROM PYTHON SOURCE LINES 162-164 绘制微平均精确率-召回率曲线 .............................................. .. GENERATED FROM PYTHON SOURCE LINES 164-176 .. code-block:: Python from collections import Counter display = PrecisionRecallDisplay( recall=recall["micro"], precision=precision["micro"], average_precision=average_precision["micro"], prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size, ) display.plot(plot_chance_level=True) _ = display.ax_.set_title("Micro-averaged over all classes") .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_003.png :alt: Micro-averaged over all classes :srcset: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_003.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 177-179 绘制每个类别的精确率-召回率曲线和等值F1曲线 ............................................................ .. GENERATED FROM PYTHON SOURCE LINES 179-220 .. code-block:: Python from itertools import cycle import matplotlib.pyplot as plt # 设置绘图细节 colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"]) _, ax = plt.subplots(figsize=(7, 8)) f_scores = np.linspace(0.2, 0.8, num=4) lines, labels = [], [] for f_score in f_scores: x = np.linspace(0.01, 1) y = f_score * x / (2 * x - f_score) (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2) plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02)) display = PrecisionRecallDisplay( recall=recall["micro"], precision=precision["micro"], average_precision=average_precision["micro"], ) display.plot(ax=ax, name="Micro-average precision-recall", color="gold") for i, color in zip(range(n_classes), colors): display = PrecisionRecallDisplay( recall=recall[i], precision=precision[i], average_precision=average_precision[i], ) display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color) # 添加等F1曲线的图例 handles, labels = display.ax_.get_legend_handles_labels() handles.extend([l]) labels.extend(["iso-f1 curves"]) # 设置图例和坐标轴 ax.legend(handles=handles, labels=labels, loc="best") ax.set_title("Extension of Precision-Recall curve to multi-class") plt.show() .. image-sg:: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_004.png :alt: Extension of Precision-Recall curve to multi-class :srcset: /auto_examples/model_selection/images/sphx_glr_plot_precision_recall_004.png :class: sphx-glr-single-img .. rst-class:: sphx-glr-timing **Total running time of the script:** (0 minutes 0.174 seconds) .. _sphx_glr_download_auto_examples_model_selection_plot_precision_recall.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: binder-badge .. image:: images/binder_badge_logo.svg :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_precision_recall.ipynb :alt: Launch binder :width: 150 px .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: plot_precision_recall.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_precision_recall.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_precision_recall.zip ` .. include:: plot_precision_recall.recommendations .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_