开发 scikit-learn 估计器#
无论你是提议将一个估计器包含在 scikit-learn 中, 开发一个与 scikit-learn 兼容的独立包,还是 为你的项目实现自定义组件,本章节 详细说明了如何开发能够与 scikit-learn 管道和模型选择工具安全交互的对象。
scikit-learn 对象的 API#
为了拥有统一的 API,我们尝试为所有对象提供一个共同的基本 API。 此外,为了避免框架代码的泛滥,我们尝试采用简单的约定,并尽可能减少 一个对象必须实现的方法数量。
scikit-learn API 的元素在 术语表 中有更明确的描述。
不同的对象#
scikit-learn 中的主要对象包括(一个类可以实现多个接口):
- 估计器(Estimator):
基础对象,实现了从数据中学习的
fit
方法,可以是:estimator = estimator.fit(data, targets)
或者:
estimator = estimator.fit(data)
- 预测器(Predictor):
对于监督学习或某些无监督问题,实现了:
prediction = predictor.predict(data)
分类算法通常还提供一种量化预测确定性的方法,使用
decision_function
或predict_proba
probability = predictor.predict_proba(data)
- 转换器(Transformer):
用于以监督或无监督方式修改数据(例如通过添加、更改或删除列,但不包括添加或删除行)。实现了:
new_data = transformer.transform(data)
当拟合和转换可以更高效地一起执行而不是分开执行时,实现了:
new_data = transformer.fit_transform(data)
- 模型(Model):
一个能够给出 拟合优度 度量或未见数据似然性的模型,实现了(值越高越好):
score = model.score(data)
估计器#
API 有一个主要的对象:估计器。估计器是一个基于某些训练数据拟合模型并能够推断新数据属性的对象。它可以是一个分类器或回归器。所有估计器都实现了 fit 方法:
estimator.fit(X, y)
所有内置估计器还有一个 set_params
方法,用于设置与数据无关的参数(覆盖之前传递给 __init__
的参数值)。
主 scikit-learn 代码库中的所有估计器都应该继承自 sklearn.base.BaseEstimator
。
实例化#
这涉及对象的创建。对象的 __init__
方法可能接受决定估计器行为的常量作为参数(如 SVM 中的 C 常数)。然而,它不应该接受实际的训练数据作为参数,因为这是留给 fit()
方法的:
clf2 = SVC(C=2.3)
clf3 = SVC([[1, 2], [2, 3]], [-1, 1]) # 错误!
``__init__`` 接受的参数都应该是带有默认值的关键字参数。换句话说,用户应该能够在不传递任何参数的情况下实例化一个估计器。这些参数都应该对应于描述模型或估计器试图解决的优化问题的超参数。这些初始参数(或参数)始终被估计器记住。
还要注意,它们不应该在“属性”部分下记录,而应该在该估计器的“参数”部分下记录。
此外,** __init__
接受的每个关键字参数都应该对应于实例上的一个属性**。Scikit-learn 依赖于此来在执行模型选择时找到要设置的相关属性。
总结一下,一个 __init__
应该看起来像这样:
def __init__(self, param1=1, param2=2):
self.param1 = param1
self.param2 = param2
不应该有逻辑,甚至不应该有输入验证,参数也不应该被修改。
相应的逻辑应该放在使用参数的地方,通常在 fit
中。
以下是错误的:
def __init__(self, param1=1, param2=2, param3=3):
# 错误:参数不应该被修改
if param1 > 1:
param2 += 1
self.param1 = param1
# 错误:对象的属性应该与构造函数中的参数名称完全一致
self.param3 = param2
推迟验证的原因是同样的验证必须在 set_params
中执行,
这在 GridSearchCV
等算法中使用。
拟合#
接下来你可能想要做的是在模型中估计一些参数。这通过 fit()
方法实现。
fit()
方法接受训练数据作为参数,对于无监督学习可以是一个数组,
对于监督学习可以是两个数组。
注意模型是通过 X
和 y
进行拟合的,但对象不持有对 X
和 y
的引用。
然而,也有一些例外,比如在预计算内核的情况下,这些数据必须存储以供预测方法使用。
参数 |
|
---|---|
X |
形状为 (n_samples, n_features) 的类数组对象 |
y |
形状为 (n_samples,) 的类数组对象 |
kwargs |
可选的数据依赖参数 |
X.shape[0]
应该与y.shape[0]
相同。如果这个要求没有满足,
- 应该抛出一个类型为
ValueError
的异常。 y
在无监督学习的情况下可能会被忽略。然而,为了使估计器能够作为可以混合监督和无监督变换器的管道的一部分使用,即使是无监督估计器也需要接受一个在第二个位置的y=None
关键字参数,该参数仅被估计器忽略。出于同样的原因,如果实现了fit_predict
、fit_transform
、score
和partial_fit
方法,它们需要接受第二个位置的y
参数。
该方法应返回对象( self
)。这种模式对于在 IPython 会话中实现快速的一行代码非常有用,例如:
y_predicted = SVC(C=100).fit(X_train, y_train).predict(X_test)
根据算法的性质, fit
有时也可以接受额外的关键字参数。然而,任何在访问数据之前可以赋值的参数都应该是 __init__
关键字参数。fit 参数应仅限于直接依赖于数据的变量。例如,从数据矩阵 X
预计算的 Gram 矩阵或亲和矩阵是数据依赖的。一个容差停止准则 tol
不是直接数据依赖的(尽管根据某些评分函数的最佳值可能是)。
当调用 fit
时,任何之前的 fit
调用都应该被忽略。通常,调用 estimator.fit(X1)
然后调用 estimator.fit(X2)
应该与仅调用 estimator.fit(X2)
相同。然而,在实践中,当 fit
依赖于某些随机过程时,这可能不成立,参见 random_state 。另一个例外是当支持它的估计器的超参数 warm_start
设置为 True
时。 warm_start=True
意味着重用估计器可训练参数的先前状态,而不是使用默认的初始化策略。
估计属性#
从数据中估计得到的属性必须始终以尾部下划线结尾,例如,某些回归估计器的系数在调用 fit
后会存储在 coef_
属性中。
预计在第二次调用 fit
时会覆盖估计的属性。
可选参数#
在迭代算法中,迭代次数应通过名为 n_iter
的整数来指定。
通用属性#
- 期望表格输入的估计器应在
fit
时设置n_features_in_
属性,以指示估计器在后续调用predict
或transform
时预期的特征数量。详情请参见 `SLEP010
<https://scikit-learn-enhancement-proposals.readthedocs.io/en/latest/slep010/proposal.html>`_ 。
自定义估计器#
如果你想实现一个与 scikit-learn 兼容的新估计器,无论是仅为个人使用还是为了向 scikit-learn 贡献代码,除了上述概述的 scikit-learn API 之外,你还应该了解 scikit-learn 的几个内部机制。你可以通过在实例上运行 check_estimator
来检查你的估计器是否符合 scikit-learn 接口和标准。 parametrize_with_checks
pytest 装饰器也可以使用(详见其文档字符串及与 pytest
的可能交互):
>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.svm import LinearSVC
>>> check_estimator(LinearSVC()) # 通过
使一个类与 scikit-learn 估计器接口兼容的主要动机可能是你想将其与模型评估和选择工具(如 model_selection.GridSearchCV
和 pipeline.Pipeline
)一起使用。
在详细介绍所需的接口之前,我们描述两种更容易实现正确接口的方法。
>>> class TemplateClassifier(ClassifierMixin, BaseEstimator):
...
... def __init__(self, demo_param='demo'):
... self.demo_param = demo_param
...
... def fit(self, X, y):
...
... # 检查 X 和 y 是否具有正确的形状
... X, y = check_X_y(X, y)
... # 存储在拟合过程中看到的类别
... self.classes_ = unique_labels(y)
...
... self.X_ = X
... self.y_ = y
... # 返回分类器
... return self
...
... def predict(self, X):
...
... # 检查是否已调用 fit
... check_is_fitted(self)
...
... # 输入验证
... X = check_array(X)
...
... closest = np.argmin(euclidean_distances(X, self.X_), axis=1)
... return self.y_[closest]
get_params 和 set_params#
- 所有 scikit-learn 估计器都有
get_params
和set_params
函数。 get_params
函数不接受任何参数,并返回一个包含估计器__init__
参数及其值的字典。
它必须接受一个关键字参数 deep
,该参数接收一个布尔值,用于确定方法是否应返回子估计器的参数(对于大多数估计器,这可以忽略)。 deep
的默认值应为 True
。例如,考虑以下估计器:
>>> from sklearn.base import BaseEstimator
>>> from sklearn.linear_model import LogisticRegression
>>> class MyEstimator(BaseEstimator):
... def __init__(self, subestimator=None, my_extra_param="random"):
... self.subestimator = subestimator
... self.my_extra_param = my_extra_param
参数 deep
将控制是否报告 subestimator
的参数。因此,当 deep=True
时,输出将是:
>>> my_estimator = MyEstimator(subestimator=LogisticRegression())
>>> for param, value in my_estimator.get_params(deep=True).items():
... print(f"{param} -> {value}")
my_extra_param -> random
subestimator__C -> 1.0
subestimator__class_weight -> None
subestimator__dual -> False
subestimator__fit_intercept -> True
subestimator__intercept_scaling -> 1
subestimator__l1_ratio -> None
subestimator__max_iter -> 100
subestimator__multi_class -> deprecated
subestimator__n_jobs -> None
subestimator__penalty -> l2
subestimator__random_state -> None
subestimator__solver -> lbfgs
subestimator__tol -> 0.0001
subestimator__verbose -> 0
subestimator__warm_start -> False
subestimator -> LogisticRegression()
通常, subestimator
有一个名称(例如,在 Pipeline
对象中的命名步骤),在这种情况下,键应该变成 <name>__C
、 <name>__class_weight
等。
而当 deep=False
时,输出将是:
>>> for param, value in my_estimator.get_params(deep=False).items():
... print(f"{param} -> {value}")
my_extra_param -> random
subestimator -> LogisticRegression()
另一方面, set_params
接受 __init__
的参数作为关键字参数,将它们解包成 'parameter': value
形式的字典,并使用这个字典设置估计器的参数。返回值必须是估计器本身。
虽然 get_params
机制不是必需的(参见下面的 克隆 ),但 set_params
函数是必要的,因为它用于在网格搜索期间设置参数。
实现这些函数并获得合理的 __repr__
方法的最简单方法是继承自 sklearn.base.BaseEstimator
。如果你不想让你的代码依赖于 scikit-learn,实现接口的最简单方法是:
def get_params(self, deep=True):
# 假设这个估计器有参数 "alpha" 和 "recursive"
return {"alpha": self.alpha, "recursive": self.recursive}
def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self
参数和初始化#
正如 model_selection.GridSearchCV
使用 set_params
将参数设置应用于估计器,
调用 set_params
与通过 __init__
方法设置参数的效果相同是至关重要的。
实现这一点的最简单且推荐的方法是
不在 __init__
中进行任何参数验证。
所有关于估计器参数的逻辑,
例如将字符串参数转换为函数,应在 fit
中完成。
同时,预期带有尾部下划线 _
的参数**不会在** __init__
方法中设置。
所有且仅由 fit
设置的公共属性带有尾部下划线 _
。
因此,参数是否带有尾部下划线 _
的存在用于检查估计器是否已被拟合。
克隆#
为了与 model_selection
模块一起使用,
估计器必须支持 base.clone
函数以复制估计器。
这可以通过提供 get_params
方法来实现。
如果 get_params
存在,那么 clone(estimator)
将是
type(estimator)
的一个实例,其上已调用set_params
并传入estimator.get_params()
结果的克隆。
未提供此方法的对象将在传递 safe=False
给 clone
时进行深拷贝
(使用 Python 标准函数 copy.deepcopy
)。
- 估计器可以通过定义
__sklearn_clone__
方法来自定义base.clone
的行为。 __sklearn_clone__
必须返回估计器的实例。
当在估计器上调用 base.clone
时, __sklearn_clone__
对于需要保留某些状态的估计器很有用。
例如,可以定义一个用于转换器的冻结元估计器,如下所示:
class FrozenTransformer(BaseEstimator):
def __init__(self, fitted_transformer):
self.fitted_transformer = fitted_transformer
def __getattr__(self, name):
# `fitted_transformer` 's attributes are now accessible
return getattr(self.fitted_transformer, name)
def __sklearn_clone__(self):
return self
def fit(self, X, y):
# Fitting does not change the state of the estimator
return self
def fit_transform(self, X, y=None):
# fit_transform only transforms the data
return self.fitted_transformer.transform(X, y)
管道兼容性#
为了使估计器能够与 pipeline.Pipeline
一起使用,除了最后一步之外,它需要提供一个 fit
或 fit_transform
函数。
为了能够在训练集之外的数据上评估管道,它还需要提供一个 transform
函数。
对于管道的最后一步,除了它有一个 fit
函数之外,没有特殊要求。所有 fit
和 fit_transform
函数必须
接受参数 X, y
,即使 y 未被使用。同样,为了使 score
可用,管道的最后一步需要有一个接受可选 y
的 score
函数。
估计器类型#
某些常见功能取决于传递的估计器类型。
例如,model_selection.GridSearchCV
和 model_selection.cross_val_score
中的交叉验证默认在分类器上进行分层,但在其他情况下则不然。
类似地,对于平均精度的评分器,如果预测是连续的,则需要对分类器调用 decision_function
,对回归器调用 predict
。
这种分类器和回归器之间的区别是通过 _estimator_type
属性实现的,该属性接受一个字符串值。
对于分类器,它应该是 "classifier"
,对于回归器,它应该是 "regressor"
。
回归器和用于聚类方法的 "clusterer"
,以按预期工作。继承自 ClassifierMixin
、 RegressorMixin
或 ClusterMixin
将自动设置该属性。当一个元估计器需要区分估计器类型时,应使用 base.is_classifier
等辅助函数,而不是直接检查 _estimator_type
。
特定模型#
分类器应接受 y
(目标)参数传递给 fit
方法,这些参数可以是字符串或整数的序列(列表、数组)。它们不应假设类标签是连续的整数范围;相反,它们应在一个 classes_
属性或属性中存储类标签列表。此属性中类标签的顺序应与 predict_proba
、 predict_log_proba
和 decision_function
返回值的顺序一致。最简单的方法是在 fit
中添加如下代码:
self.classes_, y = np.unique(y, return_inverse=True)
这会返回一个新的 y
,其中包含类索引而非标签,范围在 [0, n_classes
) 内。
分类器的 predict
方法应返回包含 classes_
中类标签的数组。在实现了 decision_function
的分类器中,可以通过以下方式实现:
def predict(self, X):
D = self.decision_function(X)
return self.classes_[np.argmax(D, axis=1)]
在线性模型中,系数存储在一个名为 coef_
的数组中,独立项存储在 intercept_
中。 sklearn.linear_model._base
包含一些实现常见线性模型模式的基类和混合类。
multiclass
模块包含处理多类和多标签问题的有用函数。
HTML 表示的开发者 API#
Warning
该 HTML 表示 API 是实验性的,API 可能会发生变化。
继承自 BaseEstimator
的估计器显示
在交互式编程环境(如 Jupyter 笔记本)中,它们会以 HTML 形式呈现自身。例如,我们可以显示这个 HTML 图表:
from sklearn.base import BaseEstimator
BaseEstimator()
通过在估计器实例上调用函数 estimator_html_repr
可以获得原始的 HTML 表示。
要自定义链接到估计器文档的 URL(即点击“?”图标时),请重写 _doc_link_module
和 _doc_link_template
属性。此外,您可以提供一个 _doc_link_url_param_generator
方法。将 _doc_link_module
设置为包含您的估计器的(顶级)模块的名称。如果该值与顶级模块名称不匹配,HTML 表示将不包含指向文档的链接。对于 scikit-learn 估计器,这被设置为 "sklearn"
。
_doc_link_template
用于构建最终的 URL。默认情况下,它可以包含两个变量:estimator_module
(包含估计器的模块的全名)和estimator_name
(估计器的类名)。如果您需要更多变量,应该实现_doc_link_url_param_generator
方法,该方法应返回一个包含变量及其值的字典。此字典将用于渲染_doc_link_template
。
编码指南#
以下是关于如何为 scikit-learn 编写新代码以及可能适用于外部项目的一些指南。当然,存在特殊情况,这些规则会有例外。然而,在提交新代码时遵循这些规则使得审查更容易,从而新代码可以更快地集成。
统一格式的代码使得共享代码所有权更加容易。scikit-learn 项目试图紧密遵循 PEP8 中详述的官方 Python 指南, 详细说明代码应如何格式化和缩进。请阅读并遵循这些规则。
此外,我们添加以下准则:
在非类名中使用下划线分隔单词:
n_samples
而不是nsamples
。避免在一行中使用多个语句。在控制流语句(
if
/for
)后倾向于换行。对于 scikit-learn 内部的引用,使用相对导入。
单元测试是前述规则的例外;它们应使用绝对导入,与客户端代码完全一致。一个推论是,如果
sklearn.foo
导出一个在sklearn.foo.bar.baz
中实现的类或函数,测试应从sklearn.foo
导入它。请不要在任何情况下使用
import *
。根据 官方 Python 建议 ,这被认为是有害的。它使代码更难阅读,因为符号的来源不再明确引用,最重要的是,它阻止了使用像 pyflakes 这样的静态分析工具来自动发现 scikit-learn 中的错误。在所有文档字符串中使用 numpy 文档字符串标准 。
我们喜欢的一个代码示例可以在这里找到 这里 。
输入验证#
模块 sklearn.utils
包含用于进行输入验证和转换的各种函数。有时, np.asarray
足以进行验证;不要使用 np.asanyarray
或 np.atleast_2d
,因为这些允许 NumPy 的 np.matrix
通过,它具有不同的 API(例如, *
在 np.matrix
上表示点积,而在 np.ndarray
上表示哈达玛积)。
在其他情况下,确保在任何类似数组的参数上调用 check_array
。
传递给 scikit-learn API 函数。确切的参数使用主要取决于是否以及需要接受哪些 scipy.sparse
矩阵。
更多信息,请参阅 开发者工具 页面。
随机数#
如果你的代码依赖于随机数生成器,请不要使用 numpy.random.random()
或类似函数。为了确保错误检查中的可重复性,该函数应接受一个关键字参数 random_state
,并使用它来构造一个 numpy.random.RandomState
对象。请参阅 开发者工具 中的 sklearn.utils.check_random_state
。
以下是一个使用上述部分指南的简单代码示例:
from sklearn.utils import check_array, check_random_state
def choose_random_sample(X, random_state=0):
"""从 X 中随机选择一个点。
参数
----------
X : 形状为 (n_samples, n_features) 的类数组对象
表示数据的数组。
random_state : int 或 RandomState 实例,默认=0
用于选择随机样本的伪随机数生成器的种子。在多次函数调用中传递 int 以获得可重现的输出。
参见 :term:`术语表 <random_state>` 。
返回
-------
x : 形状为 (n_features,) 的 ndarray
从 X 中随机选择的点。
"""
X = check_array(X)
random_state = check_random_state(random_state)
i = random_state.randint(X.shape[0])
return X[i]
如果你在估计器中而不是独立函数中使用随机性,则需要遵循一些额外的指南。
首先,估计器应在其 __init__
方法中接受一个 random_state
参数,默认值为 None
。它应将该参数的值 未修改地 存储在属性 random_state
中。 fit
方法可以调用 check_random_state
方法来获取实际的随机数生成器。如果出于某种原因,在 fit
之后需要随机性,
RNG 应该存储在一个属性 random_state_
中。
下面的示例应该能清楚地说明这一点:
class GaussianNoise(BaseEstimator, TransformerMixin):
"""这个估计器忽略其输入并返回随机的高斯噪声。
它也不遵循所有 scikit-learn 约定,
但展示了如何处理随机性。
"""
def __init__(self, n_components=100, random_state=None):
self.random_state = random_state
self.n_components = n_components
# 参数无论如何都会被忽略,所以我们将其设为可选
def fit(self, X=None, y=None):
self.random_state_ = check_random_state(self.random_state)
def transform(self, X):
n_samples = X.shape[0]
return self.random_state_.randn(n_samples, self.n_components)
这种设置的原因是可重复性:
当一个估计器被 fit
两次到相同的数据时,
它应该每次都产生相同的模型,
因此验证在 fit
中进行,而不是在 __init__
中。
测试中的数值断言#
当断言连续值数组的准等价性时,
请使用 sklearn.utils._testing.assert_allclose
。
相对容差会根据提供的数组的数据类型自动推断(特别是对于 float32 和 float64 数据类型),但你可以通过 rtol
覆盖。
当比较零元素数组时,请提供一个非零值作为绝对容差,通过 atol
。
更多信息,请参阅 sklearn.utils._testing.assert_allclose
的文档字符串。