Visualizations with Display Objects
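In this example, we construct the display objects ConfusionMatrixDisplay, RocCurveDisplay, and PrecisionRecallDisplay directly from their respective metrics. This is an alternative to using their corresponding plot functions when a model's predictions are already computed or are expensive to compute. Note that this is advanced usage; in general, we recommend using the respective plot functions.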
Load data and train model
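For this example, we load a blood transfusion service center dataset from OpenML <https://www.openml.org/d/1464>. This is a binary classification problem where the target is whether an individual donated blood. The data is then split into a training set and a test set, and a logistic regression model is fitted on the training set.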
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
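# Download (and cache) the dataset from OpenML; this needs network access.
X, y = fetch_openml(data_id=1464, return_X_y=True)
# Stratified split: both sets keep the same class proportions.
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y)
# Standardize the features, then fit a logistic regression.
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)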
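Because fetch_openml needs network access, here is an offline stand-in. It assumes a synthetic imbalanced dataset from make_classification (not the blood transfusion data) and repeats the same split-and-fit steps, printing the positive rate of each split to show what stratify=y preserves.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Assumption: an imbalanced synthetic binary dataset replaces the OpenML download.
X, y = make_classification(
    n_samples=750,
    n_features=4,
    n_informative=2,
    n_redundant=0,
    weights=[0.75, 0.25],
    random_state=0,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# With stratify=y, the fraction of positives is (nearly) identical in both splits.
train_ratio = y_train.mean()
test_ratio = y_test.mean()
print(f"positive rate: train={train_ratio:.3f}, test={test_ratio:.3f}")

clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
clf.fit(X_train, y_train)

# Pipeline exposes the final classifier's classes_; classes_[1] is used
# later in this example as the positive label.
print(clf.classes_)
```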
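With the fitted model, we compute the model's predictions on the test set. These predictions are used to compute the confusion matrix, which is plotted with ConfusionMatrixDisplay.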
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
y_pred = clf.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
cm_display = ConfusionMatrixDisplay(cm).plot()
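ConfusionMatrixDisplay(cm) labels its axes 0, 1, ... by position, not by class name. A sketch, assuming synthetic data with hypothetical string labels "no"/"yes": fixing the matrix order with labels=clf.classes_ and passing the same array as display_labels keeps the tick labels aligned with the cells, and normalize="true" turns each row into per-class rates.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import train_test_split

X, y_int = make_classification(n_samples=400, weights=[0.7, 0.3], random_state=0)
# Hypothetical string labels, to show that tick labels follow clf.classes_.
y = np.where(y_int == 1, "yes", "no")
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
y_pred = clf.predict(X_test)

# Row/column order is pinned to clf.classes_; each row is divided by its total.
cm = confusion_matrix(y_test, y_pred, labels=clf.classes_, normalize="true")
cm_display = ConfusionMatrixDisplay(cm, display_labels=clf.classes_).plot()
print(cm)
```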
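The ROC curve requires either the probabilities or the non-thresholded decision values from the estimator. Since logistic regression provides a decision function, we use it to plot the ROC curve: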
from sklearn.metrics import RocCurveDisplay, roc_curve
y_score = clf.decision_function(X_test)
fpr, tpr, _ = roc_curve(y_test, y_score, pos_label=clf.classes_[1])
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr).plot()
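Decision values suffice because an ROC curve depends only on how the scores rank the samples. For a binary LogisticRegression, predict_proba(X)[:, 1] is a sigmoid of decision_function(X), a strictly increasing transform, so both scores should trace the same curve. The sketch below, assuming synthetic data and the Agg backend, checks this through the AUC. It also passes roc_auc to the display and a name to plot() so the line gets a legend label; a display without any label can trigger matplotlib's "No artists with labels found" legend warning. The name string is an arbitrary choice.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
from sklearn.model_selection import train_test_split

# Assumption: synthetic binary data.
X, y = make_classification(n_samples=400, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
pos_label = clf.classes_[1]

# Non-thresholded decision values ...
fpr_dec, tpr_dec, _ = roc_curve(
    y_test, clf.decision_function(X_test), pos_label=pos_label
)
# ... and positive-class probabilities (a sigmoid of those values).
fpr_prob, tpr_prob, _ = roc_curve(
    y_test, clf.predict_proba(X_test)[:, 1], pos_label=pos_label
)

roc_auc = auc(fpr_dec, tpr_dec)
print(f"AUC from decision_function: {roc_auc:.3f}")
print(f"AUC from predict_proba:     {auc(fpr_prob, tpr_prob):.3f}")

# Supplying roc_auc and a name gives the curve a legend entry.
roc_display = RocCurveDisplay(fpr=fpr_dec, tpr=tpr_dec, roc_auc=roc_auc).plot(
    name="LogisticRegression"
)
```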
Similarly, the precision-recall curve can be plotted using y_score from the previous section.
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve
prec, recall, _ = precision_recall_curve(y_test, y_score, pos_label=clf.classes_[1])
pr_display = PrecisionRecallDisplay(precision=prec, recall=recall).plot()
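precision_recall_curve returns precision and recall arrays that are one element longer than thresholds: the final point (precision 1, recall 0) has no threshold. A sketch, assuming synthetic imbalanced data and the Agg backend, that checks this and passes average_precision_score to the display so its legend reports the average precision. The name passed to plot() is an arbitrary label.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    PrecisionRecallDisplay,
    average_precision_score,
    precision_recall_curve,
)
from sklearn.model_selection import train_test_split

# Assumption: synthetic imbalanced binary data.
X, y = make_classification(n_samples=400, weights=[0.8, 0.2], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
y_score = clf.decision_function(X_test)
pos_label = clf.classes_[1]

prec, recall, thresholds = precision_recall_curve(y_test, y_score, pos_label=pos_label)
# One extra point at the end of precision/recall, with no matching threshold.
print(len(prec), len(recall), len(thresholds))

# Average precision summarizes the curve as a single number.
ap = average_precision_score(y_test, y_score, pos_label=pos_label)
pr_display = PrecisionRecallDisplay(
    precision=prec, recall=recall, average_precision=ap
).plot(name="LogisticRegression")
```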
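Combining the display objects into a single plot
The display objects store the computed values that were passed as arguments. This makes it easy to combine the visualizations using matplotlib's API. In the following example, we place the displays next to each other in a row.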
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
roc_display.plot(ax=ax1)
pr_display.plot(ax=ax2)
plt.show()
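The same approach extends to all three displays. A sketch, assuming synthetic data and the Agg backend: each display is built once from precomputed values and then drawn onto an axis we create; after plot(ax=...), the display's ax_ and figure_ attributes refer to that axis and its figure, so further matplotlib customization can go through either handle.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    PrecisionRecallDisplay,
    RocCurveDisplay,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from sklearn.model_selection import train_test_split

# Assumption: synthetic binary data.
X, y = make_classification(n_samples=400, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)
y_score = clf.decision_function(X_test)

# Build every display from precomputed values ...
cm_display = ConfusionMatrixDisplay(confusion_matrix(y_test, clf.predict(X_test)))
fpr, tpr, _ = roc_curve(y_test, y_score)
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr)
prec, recall, _ = precision_recall_curve(y_test, y_score)
pr_display = PrecisionRecallDisplay(precision=prec, recall=recall)

# ... then draw each one onto an axis of a shared figure.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
cm_display.plot(ax=ax1, colorbar=False)
roc_display.plot(ax=ax2, name="LogisticRegression")  # name gives a legend label
pr_display.plot(ax=ax3, name="LogisticRegression")
ax1.set_title("Confusion matrix")
ax2.set_title("ROC curve")
ax3.set_title("Precision-recall curve")
fig.tight_layout()
```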
Related examples
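Precision-Recall
Multiclass Receiver Operating Characteristic (ROC)
Detection error tradeoff (DET) curve
Post-tuning the decision threshold for cost-sensitive learning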