.. _plotting_api: ================================ 使用 Plotting API 进行开发 ================================ Scikit-learn 定义了一个简单的 API,用于为机器学习创建可视化。这个 API 的关键特性是只需计算一次,并且具有在事后调整可视化的灵活性。本节面向希望开发或维护绘图工具的开发者。对于使用方法,用户应参考 :ref:`用户指南 ` 。 Plotting API 概述 --------------------- 这种逻辑被封装在一个显示对象中,其中计算的数据被存储,并且在 `plot` 方法中进行绘图。显示对象的 `__init__` 方法仅包含创建可视化所需的数据。 `plot` 方法接收仅与可视化相关的参数,例如 matplotlib 轴。 `plot` 方法将 matplotlib 艺术家存储为属性,允许通过显示对象进行样式调整。 `Display` 类应定义一个或两个类方法: `from_estimator` 和 `from_predictions` 。这些方法允许从估计器和一些数据或从真实值和预测值创建 `Display` 对象。在这些类方法使用计算值创建显示对象后,调用显示对象的 `plot` 方法。请注意, `plot` 方法定义了与 matplotlib 相关的属性,例如线条艺术家。这允许在调用 `plot` 方法后进行自定义。 例如, `RocCurveDisplay` 定义了以下方法和属性:: class RocCurveDisplay: def __init__(self, fpr, tpr, roc_auc, estimator_name): ... self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc self.estimator_name = estimator_name @classmethod def from_estimator(cls, estimator, X, y): # 获取预测值 y_pred = estimator.predict_proba(X)[:, 1] return cls.from_predictions(y, y_pred, estimator.__class__.__name__) @classmethod def from_predictions(cls, y, y_pred, estimator_name): # 从 y 和 y_pred 计算 ROC fpr, tpr, roc_auc = ... viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name) return viz.plot() def plot(self, ax=None, name=None, **kwargs): ... self.line_ = ... self.ax_ = ax self.figure_ = ax.figure_ 更多信息请参阅 :ref:`sphx_glr_auto_examples_miscellaneous_plot_roc_curve_visualization_api.py` 和 :ref:`用户指南 ` 。 多轴绘图 ---------- 一些绘图工具,如 :func:`~sklearn.inspection.PartialDependenceDisplay.from_estimator` 和 :class:`~sklearn.inspection.PartialDependenceDisplay` 支持在多个轴上绘图。支持两种不同的场景: 1. 如果传入一个轴列表, `plot` 将检查轴的数量是否与预期一致,然后在这些轴上绘图。 2. 如果传入单个轴,该轴定义了一个空间,用于放置多个轴。在这种情况下,我们建议使用 matplotlib 的 `~matplotlib.gridspec.GridSpecFromSubplotSpec` 来分割空间:: import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpecFromSubplotSpec fig, ax = plt.subplots() gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec()) ax_top_left = fig.add_subplot(gs[0, 0]) ax_top_right = fig.add_subplot(gs[0, 1]) ax_bottom = fig.add_subplot(gs[1, :]) 默认情况下, `plot` 中的 `ax` 关键字为 `None` 。在这种情况下,会创建单个轴,并使用 gridspec API 创建绘图区域。 例如,:meth:`~sklearn.inspection.PartialDependenceDisplay.from_estimator` 使用此 API 绘制多条线和等高线。定义绘图区域的轴 bounding box 保存在 `bounding_ax_` 属性中。创建的各个轴存储在 `axes_` ndarray 中,对应于轴在网格上的位置。未使用的位置设置为 `None` 。此外,matplotlib Artists 存储在 `lines_` 和 `contours_` 中,其中键是网格上的位置。当传递一个轴列表时, `axes_` 、 `lines_` 和 `contours_` 是一个 1d ndarray,对应于传递的轴列表。