binder

基准测试 - 比较估计器性能#

benchmarking 模块允许你轻松地组织基准测试实验,在这些实验中,你希望比较一个或多个算法在一个或多个数据集和基准配置上的性能。

一般来说,基准测试很容易出错,从而得出关于估计器性能的错误结论——参见普林斯顿大学2022年的这项研究 <https://reproducible.cs.princeton.edu/>,其中提供了许多同行评审学术论文中此类错误的例子作为证据。

sktimebenchmarking 模块旨在提供基准测试功能,同时强制执行最佳实践和结构,以帮助用户避免犯错(如数据泄露等),从而使他们的结果无效。benchmarking 模块设计时考虑到了易用性,因此它直接与 sktime 对象和类接口。之前开发的估计器应无需修改即可使用。

本笔记本演示了 benchmarking 模块的使用。

[1]:
from sktime.benchmarking.forecasting import ForecastingBenchmark
from sktime.datasets import load_airline
from sktime.forecasting.naive import NaiveForecaster
from sktime.performance_metrics.forecasting import MeanSquaredPercentageError
from sktime.split import ExpandingWindowSplitter

实例化一个基准类实例#

在这个例子中,我们正在比较预测估计器。

[2]:
benchmark = ForecastingBenchmark()

添加竞争估计器#

我们将不同的竞争估计器添加到基准实例中。所有添加的估计器将自动通过每个添加的基准任务运行,并编译其结果。

[3]:
benchmark.add_estimator(
    estimator=NaiveForecaster(strategy="mean", sp=12),
    estimator_id="NaiveForecaster-mean-v1",
)
benchmark.add_estimator(
    estimator=NaiveForecaster(strategy="last", sp=12),
    estimator_id="NaiveForecaster-last-v1",
)

添加基准测试任务#

这些是每个估计器将被测试的预测/验证任务及其结果的汇总。

基准测试任务的确切参数取决于目标是预测、分类等,但通常它们是相似的。以下是定义预测基准测试任务所需的参数。

指定交叉验证分割方案#

使用标准的 sktime 对象定义交叉验证分割机制。

[4]:
cv_splitter = ExpandingWindowSplitter(
    initial_window=24,
    step_length=12,
    fh=12,
)

指定性能指标#

使用标准的 sktime 对象定义用于比较估计器的性能指标。

[5]:
scorers = [MeanSquaredPercentageError()]

指定数据集加载器#

定义数据集加载器,这些是可调用对象(函数),应返回一个数据集。通常这是一个返回包含整个数据集的数据帧的可调用对象。可以使用 sktime 定义的数据集,或者定义自己的数据集。像下面这个简单的例子就足够了:

def my_dataset_loader():
    return pd.read_csv("path/to/data.csv")

数据集将在运行基准测试任务时加载,经过交叉验证机制处理,随后评估器将在数据集分片上进行测试。

[6]:
dataset_loaders = [load_airline]

将任务添加到基准实例#

使用先前定义的对象将任务添加到基准实例。可选地使用循环等,以便轻松设置多个基准任务并重用参数。

[7]:
for dataset_loader in dataset_loaders:
    benchmark.add_task(
        dataset_loader,
        cv_splitter,
        scorers,
    )

运行所有任务-估计器组合并存储结果#

注意,run 不会重新运行已经有结果的任务,因此添加一个新的估计器并再次运行 run 将只会运行该新估计器的任务。

[8]:
results_df = benchmark.run("./forecasting_results.csv")
results_df.T
[8]:
0 1
validation_id [dataset=load_airline]_[cv_splitter=ExpandingW... [dataset=load_airline]_[cv_splitter=ExpandingW...
model_id NaiveForecaster-last-v1 NaiveForecaster-mean-v1
runtime_secs 0.061472 0.081733
MeanSquaredPercentageError_fold_0_test 0.024532 0.049681
MeanSquaredPercentageError_fold_1_test 0.020831 0.0737
MeanSquaredPercentageError_fold_2_test 0.001213 0.05352
MeanSquaredPercentageError_fold_3_test 0.01495 0.081063
MeanSquaredPercentageError_fold_4_test 0.031067 0.138163
MeanSquaredPercentageError_fold_5_test 0.008373 0.145125
MeanSquaredPercentageError_fold_6_test 0.007972 0.154337
MeanSquaredPercentageError_fold_7_test 0.000009 0.123298
MeanSquaredPercentageError_fold_8_test 0.028191 0.185644
MeanSquaredPercentageError_fold_9_test 0.003906 0.184654
MeanSquaredPercentageError_mean 0.014104 0.118918
MeanSquaredPercentageError_std 0.011451 0.051265

使用 nbsphinx 生成。Jupyter 笔记本可以在 这里 找到。