基准测试 - 比较估计器性能#
benchmarking
模块允许你轻松地组织基准测试实验,在这些实验中,你希望比较一个或多个算法在一个或多个数据集和基准配置上的性能。
一般来说,基准测试很容易出错,从而得出关于估计器性能的错误结论——参见普林斯顿大学2022年的这项研究 <https://reproducible.cs.princeton.edu/>,其中提供了许多同行评审学术论文中此类错误的例子作为证据。
sktime
的 benchmarking
模块旨在提供基准测试功能,同时强制执行最佳实践和结构,以帮助用户避免犯错(如数据泄露等),从而使他们的结果无效。benchmarking
模块设计时考虑到了易用性,因此它直接与 sktime
对象和类接口。之前开发的估计器应无需修改即可使用。
本笔记本演示了 benchmarking
模块的使用。
[1]:
from sktime.benchmarking.forecasting import ForecastingBenchmark
from sktime.datasets import load_airline
from sktime.forecasting.naive import NaiveForecaster
from sktime.performance_metrics.forecasting import MeanSquaredPercentageError
from sktime.split import ExpandingWindowSplitter
实例化一个基准类实例#
在这个例子中,我们正在比较预测估计器。
[2]:
benchmark = ForecastingBenchmark()
添加竞争估计器#
我们将不同的竞争估计器添加到基准实例中。所有添加的估计器将自动通过每个添加的基准任务运行,并编译其结果。
[3]:
benchmark.add_estimator(
estimator=NaiveForecaster(strategy="mean", sp=12),
estimator_id="NaiveForecaster-mean-v1",
)
benchmark.add_estimator(
estimator=NaiveForecaster(strategy="last", sp=12),
estimator_id="NaiveForecaster-last-v1",
)
添加基准测试任务#
这些是每个估计器将被测试的预测/验证任务及其结果的汇总。
基准测试任务的确切参数取决于目标是预测、分类等,但通常它们是相似的。以下是定义预测基准测试任务所需的参数。
指定交叉验证分割方案#
使用标准的 sktime
对象定义交叉验证分割机制。
[4]:
cv_splitter = ExpandingWindowSplitter(
initial_window=24,
step_length=12,
fh=12,
)
指定性能指标#
使用标准的 sktime
对象定义用于比较估计器的性能指标。
[5]:
scorers = [MeanSquaredPercentageError()]
指定数据集加载器#
定义数据集加载器,这些是可调用对象(函数),应返回一个数据集。通常这是一个返回包含整个数据集的数据帧的可调用对象。可以使用 sktime
定义的数据集,或者定义自己的数据集。像下面这个简单的例子就足够了:
def my_dataset_loader():
return pd.read_csv("path/to/data.csv")
数据集将在运行基准测试任务时加载,经过交叉验证机制处理,随后评估器将在数据集分片上进行测试。
[6]:
dataset_loaders = [load_airline]
将任务添加到基准实例#
使用先前定义的对象将任务添加到基准实例。可选地使用循环等,以便轻松设置多个基准任务并重用参数。
[7]:
for dataset_loader in dataset_loaders:
benchmark.add_task(
dataset_loader,
cv_splitter,
scorers,
)
运行所有任务-估计器组合并存储结果#
注意,run
不会重新运行已经有结果的任务,因此添加一个新的估计器并再次运行 run
将只会运行该新估计器的任务。
[8]:
results_df = benchmark.run("./forecasting_results.csv")
results_df.T
[8]:
0 | 1 | |
---|---|---|
validation_id | [dataset=load_airline]_[cv_splitter=ExpandingW... | [dataset=load_airline]_[cv_splitter=ExpandingW... |
model_id | NaiveForecaster-last-v1 | NaiveForecaster-mean-v1 |
runtime_secs | 0.061472 | 0.081733 |
MeanSquaredPercentageError_fold_0_test | 0.024532 | 0.049681 |
MeanSquaredPercentageError_fold_1_test | 0.020831 | 0.0737 |
MeanSquaredPercentageError_fold_2_test | 0.001213 | 0.05352 |
MeanSquaredPercentageError_fold_3_test | 0.01495 | 0.081063 |
MeanSquaredPercentageError_fold_4_test | 0.031067 | 0.138163 |
MeanSquaredPercentageError_fold_5_test | 0.008373 | 0.145125 |
MeanSquaredPercentageError_fold_6_test | 0.007972 | 0.154337 |
MeanSquaredPercentageError_fold_7_test | 0.000009 | 0.123298 |
MeanSquaredPercentageError_fold_8_test | 0.028191 | 0.185644 |
MeanSquaredPercentageError_fold_9_test | 0.003906 | 0.184654 |
MeanSquaredPercentageError_mean | 0.014104 | 0.118918 |
MeanSquaredPercentageError_std | 0.011451 | 0.051265 |