__sklearn_is_fitted__ 作为开发者 API#

__sklearn_is_fitted__ 方法是 scikit-learn 中用于检查估计器对象是否已被拟合的约定。这个方法通常在基于 scikit-learn 基类(如 BaseEstimator 或其子类)构建的自定义估计器类中实现。

开发者应在除 fit 之外的所有方法的开头使用 check_is_fitted 。如果他们需要自定义或加速检查过程,可以按照下面的示例实现 __sklearn_is_fitted__ 方法。

在这个示例中,自定义估计器展示了 __sklearn_is_fitted__ 方法和 check_is_fitted 实用函数作为开发者 API 的用法。 __sklearn_is_fitted__ 方法通过验证 _is_fitted 属性的存在来检查是否已拟合。

一个实现简单分类器的自定义估计器示例#

此代码片段定义了一个名为 CustomEstimator 的自定义估计器类, 该类继承自 scikit-learn 的 BaseEstimatorClassifierMixin 类, 并展示了 __sklearn_is_fitted__ 方法和 check_is_fitted 实用函数的用法。

# 作者:scikit-learn 开发者
# SPDX-License-Identifier: BSD-3-Clause

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted


class CustomEstimator(BaseEstimator, ClassifierMixin):
    def __init__(self, parameter=1):
        self.parameter = parameter

    def fit(self, X, y):
        """
        将估计器拟合到训练数据。
        """
        self.classes_ = sorted(set(y))
        # 自定义属性以跟踪估计器是否已拟合
        self._is_fitted = True
        return self

    def predict(self, X):
        """进行预测

如果估计器未拟合,则引发 NotFittedError
"""
        check_is_fitted(self)
        # 执行预测逻辑
        predictions = [self.classes_[0]] * len(X)
        return predictions

    def score(self, X, y):
        """计算分数

如果估算器未拟合,则引发 NotFittedError
"""
        check_is_fitted(self)
        # 执行评分逻辑
        return 0.5

    def __sklearn_is_fitted__(self):
        """
        检查拟合状态并返回布尔值。
        """
        return hasattr(self, "_is_fitted") and self._is_fitted

Related examples

归纳聚类

归纳聚类

文本文档的外存分类

文本文档的外存分类

离散数据结构上的高斯过程

离散数据结构上的高斯过程

元数据路由

元数据路由

Gallery generated by Sphinx-Gallery