实现估计器#
本页描述如何实现与 sktime
兼容的估计器,以及如何确保和测试兼容性。对于直接贡献给 sktime
的估计器,还有额外的步骤。
实现一个与 sktime
兼容的估计器#
实现与 sktime
兼容的估计器的高级步骤如下:
识别估计器的类型:预测器、分类器等
将该类型估计器的扩展模板复制到其目标位置
完成扩展模板
运行
sktime
测试套件和/或check_estimator
实用程序(参见 这里)如果测试套件突出显示了错误或问题,请修复它们并转到 4
有关如何实现自己的估计器的更多指导,请参阅这个关于测试接口一致性的 pydata 教程。
我的学习任务是什么?#
sktime
按照包含特定学习任务(例如,预测或时间序列分类)的模块进行结构化。为简洁起见,我们通过估算器解决的正式学习任务来定义其科学类型或“scitype”。例如,解决预测任务的估算器的scitype是“forecaster”。解决时间序列分类任务的估算器的scitype是“time series classifier”。
对于给定的科学类型,估计器应位于相应的模块中。估计器科学类型也映射到在 sktime
的 extension_templates 目录中找到的不同扩展模板。
通常,给定估计器的科学类型直接由估计器的作用决定。这通常也在与估计器相关的出版物中明确标示。例如,大多数教科书在预测的背景下提到ARIMA,所以在这种假设情况下,考虑“预测器”模板是有意义的。然后,检查模板并确认类的方法是否清晰地映射到估计器的例程上。如果不是,另一个模板可能更合适。
这里最常见的混淆点在于转换器和其他估计器类型之间,因为转换器通常作为其他类型算法的组成部分使用。
如果不确定,欢迎在 sktime
的社交频道上发布你的问题。不要慌张 - 学术出版物对估计器类型的描述不清晰并不罕见,即使是专家也可能难以正确分类。
什么是 sktime
扩展模板?#
扩展模板为新估计器的实现者提供了方便的“填充”模板。它们如下融入 sktime
的统一接口:
对于每种科学类型,都有一个由各自基类定义的公共用户接口。例如,
BaseForecaster
定义了预测器的fit
和predict
接口。所有预测器都将通过继承BaseForecaster
以相同的方式实现fit
和predict
。公共接口遵循“策略”面向对象模式。对于每种科学类型,都有一个私有的扩展接口,由扩展模板中的扩展契约定义。例如,
forecaster.py
预测器扩展模板解释了如何为一个从BaseForecaster
继承的具体预测器填写内容。在大多数扩展模板中,用户应实现私有方法(“内部”方法),例如,预测器的_fit
和_predict
。样板代码位于接口的公共部分,即fit
和predict
中。扩展接口遵循“模板”面向对象模式。
熟悉 scikit-learn
扩展的用户应注意以下与 scikit-learn
的不同之处:
公共接口,例如 fit
和 predict
,在 sktime``(具体)估计器中从未被重写。实现发生在私有、扩展侧接口中,例如 ``_fit
和 _predict
。
这避免了样板代码的重复,例如在 scikit-learn
中的 check_X
等。这也允许更丰富的样板代码,例如自动矢量化功能或输入转换。
如何使用 sktime
扩展模板#
要使用 sktime
扩展模板,请将它们复制到估计器的预期位置。在扩展模板内部,必要的操作用 todo
标记。典型的流程是通过搜索 todo
来浏览扩展模板,并执行紧随 todo
描述的操作。
扩展模板通常具有以下 todo
:
为估计器选择名称和参数
填充
__init__
:将参数写入self
,调用super
的__init__
填充模块和估计器的文档字符串。建议在参数确定后尽早进行,这通常作为实现时遵循的规范非常有用。
填写估计器的标签。一些标签是“能力”,即估计器能做什么,例如处理nans。其他标签决定了在“内部”方法``_fit``等中看到的输入格式,这些标签通常称为``X_inner_mtype``或类似名称。这在内部功能假设为``numpy.ndarray``或``pandas.DataFrame``时很有用,并有助于避免转换样板代码。类型字符串可以在``datatypes.MTYPE_REGISTER``中找到。有关数据类型约定的教程,请参见``examples/AA_datatypes_and_datasets``。
填充“内部”方法,例如
_fit
和_predict
。扩展模板中的文档字符串和注释应在此处遵循。文档字符串还描述了“内部”方法输入的保证,这些保证通常比公共方法输入的保证更强,并由已设置的标签值决定。例如,为预测器设置标签y_inner_mtype
为pd.DataFrame
可以保证_fit
看到的y
将是pandas.DataFrame
,符合sktime
中的额外数据容器规范(例如,索引类型)。在
get_test_params
中填写测试参数。参数的选择应涵盖主要估计器内部案例的区别,以实现良好的覆盖率。
一些常见的注意事项,也在扩展模板文本中描述:
__init__
参数应写入self
并且永远不应更改这种情况的特例:估计器组件,即作为估计器的参数,通常应该被克隆(通过
sklearn.clone
),并且方法应该只在克隆体上调用。方法通常应避免对参数产生副作用
非状态改变方法通常不应写入
self
通常情况下,实现
get_params
和set_params
是不需要的,因为sktime
的BaseEstimator
继承自sklearn
的。自定义的get_params
和set_params
通常只在复杂情况下需要,例如包含嵌套结构参数的异构组合,比如带有嵌套结构参数的管道。
如何测试接口一致性#
如需视频教程和更多示例,请访问我们的 pydata 教程。
使用 check_estimator
工具#
通常,测试与 sktime
接口一致性的最简单方法是使用 utils.estimator_checks
模块中的 check_estimator
方法。
当调用时,这将收集 sktime
中与估计器类型相关的测试,并在估计器上运行它们。
这可以在笔记本环境中用于手动调试。以下是运行 NaiveForecaster
完整测试套件的示例:
from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.naive import NaiveForecaster
check_estimator(NaiveForecaster)
check_estimator
工具默认会返回一个 dict
,该字典由测试/固定装置组合字符串索引,即测试名称和方括号中的固定装置组合字符串。例如:'test_repr[NaiveForecaster-2]'
,其中 test_repr
是测试名称,而 NaiveForecaster-2
是固定装置组合字符串。
返回 dict
的值要么是字符串 "PASSED"
,如果测试成功,要么是测试失败时会引发的异常。check_estimator
默认不引发异常,默认是将它们作为字典值返回。要引发异常,例如,用于调试,请使用参数 raise_exceptions=True
,这将引发异常而不是将它们作为字典值返回。在这种情况下,最多会引发一个异常,即测试执行顺序中遇到的第一个异常。
要运行或排除某些测试,请使用 tests_to_run
或 tests_to_exclude
参数。提供的值应为测试名称(字符串),或测试名称列表。请注意,测试名称不包括方括号部分。
示例,运行测试 test_constructor
并使用所有固定装置:
check_estimator(NaiveForecaster, tests_to_run="test_constructor")
{'test_constructor[NaiveForecaster]': 'PASSED'}
要运行或排除某些测试夹具组合,请使用 fixtures_to_run
或 fixtures_to_exclude
参数。提供的值应为测试夹具组合字符串(str)的名称,或此类字符串的列表。有效的字符串在使用默认参数的 check_estimator
时,正是字典的键。
示例,运行测试装置组合 "test_repr[NaiveForecaster-2]"
:
check_estimator(NaiveForecaster, fixtures_to_run="test_repr[NaiveForecaster-2]")
{'test_repr[NaiveForecaster-2]': 'PASSED'}
使用 check_estimator
调试估计器的一个有用工作流程如下:
运行
check_estimator(MyEstimator)
以查找失败的测试使用
fixtures_to_run
或tests_to_run
将子集设置为失败的测试或夹具如果失败不明显,设置
raise_exceptions=True
来引发异常并检查回溯。如果失败仍然不明确,请在包含
check_estimator
的代码行上使用高级调试器。
在仓库克隆中运行测试套件#
如果估计器的目标位置在 sktime
内,那么可以运行 sktime
测试套件。sktime
测试套件(和 CI/CD)基于 pytest
,pytest
会自动收集特定类型的所有估计器,并对给定的估计器应用测试。
关于测试框架的概述,请参阅“测试框架”文档。通用接口一致性测试包含在 TestAllEstimators
、TestAllForecasters
等类中。对于一个估计器 EstimatorName
,pytest
测试装置字符串将始终包含 EstimatorName
作为子字符串,并且与 check_estimator
返回的测试装置字符串相同。
要在控制台中仅运行给定估计器的测试,可以使用命令 pytest -k "EstimatorName"
。这通常与使用 check_estimator(EstimatorName)
具有相同的效果,只是通过直接的 pytest
调用。在使用 Visual Studio Code 或 pycharm 时,也可以使用 GUI 过滤功能对测试进行子集化 - 为此,请参阅相应 IDE 文档中的测试集成部分。
要识别应用于特定估计器的测试的代码库位置,一个快速的方法是在代码库中搜索由 check_estimator
生成的测试字符串,前面带有 ``def``(用于函数/方法定义)。
在第三方扩展包中进行测试#
对于第三方扩展包 sktime``(开源或闭源),或旨在与 ``sktime
接口兼容的第三方模块,可以通过以下方式导入和扩展 sktime
测试套件:
导入
check_estimator
,这将一次性执行sktime
中定义的测试。check_estimator
可以在任何测试框架中运行,包括unittest
和pytest
。从
sktime.utils.estimator_checks
导入parametrize_with_checks
。当在pytest
测试套件中使用时,这将用sktime
中为估计器类或实例列表定义的所有测试参数化一个测试函数,为每个估计器-测试组合运行单独的测试用例。此模式需要在测试套件中添加以下测试函数:from sktime.utils.estimator_checks import parametrize_with_checks @parametrize_with_checks(OBJS_TO_TEST) def test_sktime_api_compliance(obj, test_name): check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
导入测试类,例如
test_all_estimators.TestAllEstimators
或test_all_forecasters.TestAllForecasters
。这些导入将直接被pytest
发现。测试套件也可以通过继承测试类来扩展。
将一个与 sktime
兼容的估计器添加到 sktime
#
当将一个与 sktime
兼容的估计器添加到 sktime
本身时,需要完成一些额外的工作:
确保代码也符合
sktime
的 文档 标准。将估计器添加到
sktime
API 参考中。这是通过在docs/source/api_reference
内的正确rst
文件中添加对估计器的引用来完成的。估计器的作者应将自己添加到估计器的
"authors"
和"maintainers"
标签中,作为所贡献估计器的所有者。如果估计器依赖于软依赖项,或添加了新的软依赖项,应遵循 “依赖项”开发者指南 中的步骤
确保估计器通过
sktime
的整个本地测试套件,且估计器位于其目标位置。要仅对估计器运行测试,可以使用命令 ``pytest -k “EstimatorName”``(或使用 VS Code GUI 过滤功能)确保
get_test_params
中的测试参数选择使得估计器特定测试的运行时间保持在sktime
远程 CI/CD 的秒级顺序。
不要慌张 - 当为 sktime
贡献时,核心开发者会在他们的 PR 评审中提供上述内容的有益指导。
建议尽早开启一个草稿PR以获取反馈。
依赖于 cython 的估计器#
要在 sktime
中添加一个依赖于 cython 的估计器,需要以下额外步骤:
所有 Cython 代码应存在于
pypi
和/或conda-forge
上的一个单独包中。不应将依赖 Cython 的代码直接添加到sktime
中。为简单起见,我们在下面将这个单独的包称为home-package
。在
home-package
中,建议通过check_estimator
测试估计器,使用与sktime
相同的测试矩阵:所有支持的 Python 版本;MacOS、Linux、Windows。在
sktime
中,应添加算法的接口。如果home-package
中的算法已经通过了check_estimator
,这可以是一个简单的从home-package
导入。另外,算法可以通过委托器作为委托进行接口化,标签和方法重写可以在委托器中添加。例如,参见
MrSQM
。对于
sktime
接口,requires_cython
标签应设置为True
,python_dependencies
标签应设置为字符串"home-package"
。
如果所有设置都正确,估计器将在 sktime
中通过 CI 元素 test-cython-estimators
进行测试。请注意,此 CI 元素不涵盖完整的 Python 版本和操作系统的测试矩阵,这应在上游包中完成。