实现估计器#

本页描述如何实现与 sktime 兼容的估计器,以及如何确保和测试兼容性。对于直接贡献给 sktime 的估计器,还有额外的步骤。

实现一个与 sktime 兼容的估计器#

实现与 sktime 兼容的估计器的高级步骤如下:

  1. 识别估计器的类型:预测器、分类器等

  2. 将该类型估计器的扩展模板复制到其目标位置

  3. 完成扩展模板

  4. 运行 sktime 测试套件和/或 check_estimator 实用程序(参见 这里

  5. 如果测试套件突出显示了错误或问题,请修复它们并转到 4

有关如何实现自己的估计器的更多指导,请参阅这个关于测试接口一致性的 pydata 教程

我的学习任务是什么?#

sktime 按照包含特定学习任务(例如,预测或时间序列分类)的模块进行结构化。为简洁起见,我们通过估算器解决的正式学习任务来定义其科学类型或“scitype”。例如,解决预测任务的估算器的scitype是“forecaster”。解决时间序列分类任务的估算器的scitype是“time series classifier”。

对于给定的科学类型,估计器应位于相应的模块中。估计器科学类型也映射到在 sktimeextension_templates 目录中找到的不同扩展模板。

通常,给定估计器的科学类型直接由估计器的作用决定。这通常也在与估计器相关的出版物中明确标示。例如,大多数教科书在预测的背景下提到ARIMA,所以在这种假设情况下,考虑“预测器”模板是有意义的。然后,检查模板并确认类的方法是否清晰地映射到估计器的例程上。如果不是,另一个模板可能更合适。

这里最常见的混淆点在于转换器和其他估计器类型之间,因为转换器通常作为其他类型算法的组成部分使用。

如果不确定,欢迎在 sktime 的社交频道上发布你的问题。不要慌张 - 学术出版物对估计器类型的描述不清晰并不罕见,即使是专家也可能难以正确分类。

什么是 sktime 扩展模板?#

扩展模板为新估计器的实现者提供了方便的“填充”模板。它们如下融入 sktime 的统一接口:

  • 对于每种科学类型,都有一个由各自基类定义的公共用户接口。例如,BaseForecaster 定义了预测器的 fitpredict 接口。所有预测器都将通过继承 BaseForecaster 以相同的方式实现 fitpredict。公共接口遵循“策略”面向对象模式。

  • 对于每种科学类型,都有一个私有的扩展接口,由扩展模板中的扩展契约定义。例如,forecaster.py 预测器扩展模板解释了如何为一个从 BaseForecaster 继承的具体预测器填写内容。在大多数扩展模板中,用户应实现私有方法(“内部”方法),例如,预测器的 _fit_predict。样板代码位于接口的公共部分,即 fitpredict 中。扩展接口遵循“模板”面向对象模式。

熟悉 scikit-learn 扩展的用户应注意以下与 scikit-learn 的不同之处:

公共接口,例如 fitpredict,在 sktime``(具体)估计器中从未被重写。实现发生在私有、扩展侧接口中,例如 ``_fit_predict

这避免了样板代码的重复,例如在 scikit-learn 中的 check_X 等。这也允许更丰富的样板代码,例如自动矢量化功能或输入转换。

如何使用 sktime 扩展模板#

要使用 sktime 扩展模板,请将它们复制到估计器的预期位置。在扩展模板内部,必要的操作用 todo 标记。典型的流程是通过搜索 todo 来浏览扩展模板,并执行紧随 todo 描述的操作。

扩展模板通常具有以下 todo

  • 为估计器选择名称和参数

  • 填充 __init__:将参数写入 self,调用 super__init__

  • 填充模块和估计器的文档字符串。建议在参数确定后尽早进行,这通常作为实现时遵循的规范非常有用。

  • 填写估计器的标签。一些标签是“能力”,即估计器能做什么,例如处理nans。其他标签决定了在“内部”方法``_fit``等中看到的输入格式,这些标签通常称为``X_inner_mtype``或类似名称。这在内部功能假设为``numpy.ndarray``或``pandas.DataFrame``时很有用,并有助于避免转换样板代码。类型字符串可以在``datatypes.MTYPE_REGISTER``中找到。有关数据类型约定的教程,请参见``examples/AA_datatypes_and_datasets``。

  • 填充“内部”方法,例如 _fit_predict。扩展模板中的文档字符串和注释应在此处遵循。文档字符串还描述了“内部”方法输入的保证,这些保证通常比公共方法输入的保证更强,并由已设置的标签值决定。例如,为预测器设置标签 y_inner_mtypepd.DataFrame 可以保证 _fit 看到的 y 将是 pandas.DataFrame,符合 sktime 中的额外数据容器规范(例如,索引类型)。

  • get_test_params 中填写测试参数。参数的选择应涵盖主要估计器内部案例的区别,以实现良好的覆盖率。

一些常见的注意事项,也在扩展模板文本中描述:

  • __init__ 参数应写入 self 并且永远不应更改

  • 这种情况的特例:估计器组件,即作为估计器的参数,通常应该被克隆(通过 sklearn.clone),并且方法应该只在克隆体上调用。

  • 方法通常应避免对参数产生副作用

  • 非状态改变方法通常不应写入 self

  • 通常情况下,实现 get_paramsset_params 是不需要的,因为 sktimeBaseEstimator 继承自 sklearn 的。自定义的 get_paramsset_params 通常只在复杂情况下需要,例如包含嵌套结构参数的异构组合,比如带有嵌套结构参数的管道。

如何测试接口一致性#

如需视频教程和更多示例,请访问我们的 pydata 教程

使用 check_estimator 工具#

通常,测试与 sktime 接口一致性的最简单方法是使用 utils.estimator_checks 模块中的 check_estimator 方法。

当调用时,这将收集 sktime 中与估计器类型相关的测试,并在估计器上运行它们。

这可以在笔记本环境中用于手动调试。以下是运行 NaiveForecaster 完整测试套件的示例:

from sktime.utils.estimator_checks import check_estimator
from sktime.forecasting.naive import NaiveForecaster
check_estimator(NaiveForecaster)

check_estimator 工具默认会返回一个 dict,该字典由测试/固定装置组合字符串索引,即测试名称和方括号中的固定装置组合字符串。例如:'test_repr[NaiveForecaster-2]',其中 test_repr 是测试名称,而 NaiveForecaster-2 是固定装置组合字符串。

返回 dict 的值要么是字符串 "PASSED",如果测试成功,要么是测试失败时会引发的异常。check_estimator 默认不引发异常,默认是将它们作为字典值返回。要引发异常,例如,用于调试,请使用参数 raise_exceptions=True,这将引发异常而不是将它们作为字典值返回。在这种情况下,最多会引发一个异常,即测试执行顺序中遇到的第一个异常。

要运行或排除某些测试,请使用 tests_to_runtests_to_exclude 参数。提供的值应为测试名称(字符串),或测试名称列表。请注意,测试名称不包括方括号部分。

示例,运行测试 test_constructor 并使用所有固定装置:

check_estimator(NaiveForecaster, tests_to_run="test_constructor")

{'test_constructor[NaiveForecaster]': 'PASSED'}

要运行或排除某些测试夹具组合,请使用 fixtures_to_runfixtures_to_exclude 参数。提供的值应为测试夹具组合字符串(str)的名称,或此类字符串的列表。有效的字符串在使用默认参数的 check_estimator 时,正是字典的键。

示例,运行测试装置组合 "test_repr[NaiveForecaster-2]":

check_estimator(NaiveForecaster, fixtures_to_run="test_repr[NaiveForecaster-2]")

{'test_repr[NaiveForecaster-2]': 'PASSED'}

使用 check_estimator 调试估计器的一个有用工作流程如下:

  1. 运行 check_estimator(MyEstimator) 以查找失败的测试

  2. 使用 fixtures_to_runtests_to_run 将子集设置为失败的测试或夹具

  3. 如果失败不明显,设置 raise_exceptions=True 来引发异常并检查回溯。

  4. 如果失败仍然不明确,请在包含 check_estimator 的代码行上使用高级调试器。

在仓库克隆中运行测试套件#

如果估计器的目标位置在 sktime 内,那么可以运行 sktime 测试套件。sktime 测试套件(和 CI/CD)基于 pytestpytest 会自动收集特定类型的所有估计器,并对给定的估计器应用测试。

关于测试框架的概述,请参阅“测试框架”文档。通用接口一致性测试包含在 TestAllEstimatorsTestAllForecasters 等类中。对于一个估计器 EstimatorNamepytest 测试装置字符串将始终包含 EstimatorName 作为子字符串,并且与 check_estimator 返回的测试装置字符串相同。

要在控制台中仅运行给定估计器的测试,可以使用命令 pytest -k "EstimatorName"。这通常与使用 check_estimator(EstimatorName) 具有相同的效果,只是通过直接的 pytest 调用。在使用 Visual Studio Code 或 pycharm 时,也可以使用 GUI 过滤功能对测试进行子集化 - 为此,请参阅相应 IDE 文档中的测试集成部分。

要识别应用于特定估计器的测试的代码库位置,一个快速的方法是在代码库中搜索由 check_estimator 生成的测试字符串,前面带有 ``def``(用于函数/方法定义)。

在第三方扩展包中进行测试#

对于第三方扩展包 sktime``(开源或闭源),或旨在与 ``sktime 接口兼容的第三方模块,可以通过以下方式导入和扩展 sktime 测试套件:

  • 导入 check_estimator ,这将一次性执行 sktime 中定义的测试。check_estimator 可以在任何测试框架中运行,包括 unittestpytest

  • sktime.utils.estimator_checks 导入 parametrize_with_checks。当在 pytest 测试套件中使用时,这将用 sktime 中为估计器类或实例列表定义的所有测试参数化一个测试函数,为每个估计器-测试组合运行单独的测试用例。此模式需要在测试套件中添加以下测试函数:

    from sktime.utils.estimator_checks import parametrize_with_checks
    
    @parametrize_with_checks(OBJS_TO_TEST)
    def test_sktime_api_compliance(obj, test_name):
        check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
    
  • 导入测试类,例如 test_all_estimators.TestAllEstimatorstest_all_forecasters.TestAllForecasters。这些导入将直接被 pytest 发现。测试套件也可以通过继承测试类来扩展。

将一个与 sktime 兼容的估计器添加到 sktime#

当将一个与 sktime 兼容的估计器添加到 sktime 本身时,需要完成一些额外的工作:

  • 确保代码也符合 sktime文档 标准。

  • 将估计器添加到 sktime API 参考中。这是通过在 docs/source/api_reference 内的正确 rst 文件中添加对估计器的引用来完成的。

  • 估计器的作者应将自己添加到估计器的 "authors""maintainers" 标签中,作为所贡献估计器的所有者。

  • 如果估计器依赖于软依赖项,或添加了新的软依赖项,应遵循 “依赖项”开发者指南 中的步骤

  • 确保估计器通过 sktime 的整个本地测试套件,且估计器位于其目标位置。要仅对估计器运行测试,可以使用命令 ``pytest -k “EstimatorName”``(或使用 VS Code GUI 过滤功能)

  • 确保 get_test_params 中的测试参数选择使得估计器特定测试的运行时间保持在 sktime 远程 CI/CD 的秒级顺序。

不要慌张 - 当为 sktime 贡献时,核心开发者会在他们的 PR 评审中提供上述内容的有益指导。

建议尽早开启一个草稿PR以获取反馈。

依赖于 cython 的估计器#

要在 sktime 中添加一个依赖于 cython 的估计器,需要以下额外步骤:

  • 所有 Cython 代码应存在于 pypi 和/或 conda-forge 上的一个单独包中。不应将依赖 Cython 的代码直接添加到 sktime 中。为简单起见,我们在下面将这个单独的包称为 home-package

  • home-package 中,建议通过 check_estimator 测试估计器,使用与 sktime 相同的测试矩阵:所有支持的 Python 版本;MacOS、Linux、Windows。

  • sktime 中,应添加算法的接口。如果 home-package 中的算法已经通过了 check_estimator ,这可以是一个简单的从 home-package 导入。

  • 另外,算法可以通过委托器作为委托进行接口化,标签和方法重写可以在委托器中添加。例如,参见 MrSQM

  • 对于 sktime 接口,requires_cython 标签应设置为 Truepython_dependencies 标签应设置为字符串 "home-package"

如果所有设置都正确,估计器将在 sktime 中通过 CI 元素 test-cython-estimators 进行测试。请注意,此 CI 元素不涵盖完整的 Python 版本和操作系统的测试矩阵,这应在上游包中完成。