%%capture
!pip install hierarchicalforecast statsforecast datasetsforecast
非负最小迹(MinTrace)
大型时间序列集合通常以不同的聚合级别组织,通常要求它们的预测遵循聚合约束并且非负,这就带来了创建能够生成一致预测的新算法的挑战。
HierarchicalForecast
包提供了一套广泛的 Python 实现的层次预测算法,这些算法遵循非负层次协调。
在本笔记本中,我们将展示如何使用 HierarchicalForecast
包对 Wiki2
数据集执行非负预测协调。
您可以使用 CPU 或 GPU 在 Google Colab 中运行这些实验。
1. 加载数据
在此示例中,我们将使用Wiki2
数据集。以下单元格获取层级中不同级别的时间序列、用于从底层序列恢复完整数据集的求和数据框 `S_df`,以及表示每个层级索引的 `tags`。
import numpy as np
import pandas as pd
from datasetsforecast.hierarchical import HierarchicalData
# Load the Wiki2 hierarchical dataset: the series frame, the summing
# frame and the per-level tags (first argument is presumably the local
# download/cache directory -- confirm against datasetsforecast docs).
Y_df, S_df, tags = HierarchicalData.load('./data', 'Wiki2')
# Parse the 'ds' column into pandas datetimes.
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
Y_df.head()
unique_id | ds | y | |
---|---|---|---|
0 | Total | 2016-01-01 | 156508 |
1 | Total | 2016-01-02 | 129902 |
2 | Total | 2016-01-03 | 138203 |
3 | Total | 2016-01-04 | 115017 |
4 | Total | 2016-01-05 | 126042 |
# Preview the top-left 5x5 corner of the summing frame.
S_df.iloc[:5, :5]
de_AAC_AAG_001 | de_AAC_AAG_010 | de_AAC_AAG_014 | de_AAC_AAG_045 | de_AAC_AAG_063 | |
---|---|---|---|---|---|
Total | 1 | 1 | 1 | 1 | 1 |
de | 1 | 1 | 1 | 1 | 1 |
en | 0 | 0 | 0 | 0 | 0 |
fr | 0 | 0 | 0 | 0 | 0 |
ja | 0 | 0 | 0 | 0 | 0 |
# Display the tags mapping (hierarchy level name -> array of series ids).
tags
{'Views': array(['Total'], dtype=object),
'Views/Country': array(['de', 'en', 'fr', 'ja', 'ru', 'zh'], dtype=object),
'Views/Country/Access': array(['de_AAC', 'de_DES', 'de_MOB', 'en_AAC', 'en_DES', 'en_MOB',
'fr_AAC', 'fr_DES', 'fr_MOB', 'ja_AAC', 'ja_DES', 'ja_MOB',
'ru_AAC', 'ru_DES', 'ru_MOB', 'zh_AAC', 'zh_DES', 'zh_MOB'],
dtype=object),
'Views/Country/Access/Agent': array(['de_AAC_AAG', 'de_AAC_SPD', 'de_DES_AAG', 'de_MOB_AAG',
'en_AAC_AAG', 'en_AAC_SPD', 'en_DES_AAG', 'en_MOB_AAG',
'fr_AAC_AAG', 'fr_AAC_SPD', 'fr_DES_AAG', 'fr_MOB_AAG',
'ja_AAC_AAG', 'ja_AAC_SPD', 'ja_DES_AAG', 'ja_MOB_AAG',
'ru_AAC_AAG', 'ru_AAC_SPD', 'ru_DES_AAG', 'ru_MOB_AAG',
'zh_AAC_AAG', 'zh_AAC_SPD', 'zh_DES_AAG', 'zh_MOB_AAG'],
dtype=object),
'Views/Country/Access/Agent/Topic': array(['de_AAC_AAG_001', 'de_AAC_AAG_010', 'de_AAC_AAG_014',
'de_AAC_AAG_045', 'de_AAC_AAG_063', 'de_AAC_AAG_100',
'de_AAC_AAG_110', 'de_AAC_AAG_123', 'de_AAC_AAG_143',
'de_AAC_SPD_012', 'de_AAC_SPD_074', 'de_AAC_SPD_080',
'de_AAC_SPD_105', 'de_AAC_SPD_115', 'de_AAC_SPD_133',
'de_DES_AAG_064', 'de_DES_AAG_116', 'de_DES_AAG_131',
'de_MOB_AAG_015', 'de_MOB_AAG_020', 'de_MOB_AAG_032',
'de_MOB_AAG_059', 'de_MOB_AAG_062', 'de_MOB_AAG_088',
'de_MOB_AAG_095', 'de_MOB_AAG_109', 'de_MOB_AAG_122',
'de_MOB_AAG_149', 'en_AAC_AAG_044', 'en_AAC_AAG_049',
'en_AAC_AAG_075', 'en_AAC_AAG_114', 'en_AAC_AAG_119',
'en_AAC_AAG_141', 'en_AAC_SPD_004', 'en_AAC_SPD_011',
'en_AAC_SPD_026', 'en_AAC_SPD_048', 'en_AAC_SPD_067',
'en_AAC_SPD_126', 'en_AAC_SPD_140', 'en_DES_AAG_016',
'en_DES_AAG_024', 'en_DES_AAG_042', 'en_DES_AAG_069',
'en_DES_AAG_082', 'en_DES_AAG_102', 'en_MOB_AAG_018',
'en_MOB_AAG_022', 'en_MOB_AAG_101', 'en_MOB_AAG_124',
'fr_AAC_AAG_029', 'fr_AAC_AAG_046', 'fr_AAC_AAG_070',
'fr_AAC_AAG_087', 'fr_AAC_AAG_098', 'fr_AAC_AAG_104',
'fr_AAC_AAG_111', 'fr_AAC_AAG_112', 'fr_AAC_AAG_142',
'fr_AAC_SPD_025', 'fr_AAC_SPD_027', 'fr_AAC_SPD_035',
'fr_AAC_SPD_077', 'fr_AAC_SPD_084', 'fr_AAC_SPD_097',
'fr_AAC_SPD_130', 'fr_DES_AAG_023', 'fr_DES_AAG_043',
'fr_DES_AAG_051', 'fr_DES_AAG_058', 'fr_DES_AAG_061',
'fr_DES_AAG_091', 'fr_DES_AAG_093', 'fr_DES_AAG_094',
'fr_DES_AAG_136', 'fr_MOB_AAG_006', 'fr_MOB_AAG_030',
'fr_MOB_AAG_066', 'fr_MOB_AAG_117', 'fr_MOB_AAG_120',
'fr_MOB_AAG_121', 'fr_MOB_AAG_135', 'fr_MOB_AAG_147',
'ja_AAC_AAG_038', 'ja_AAC_AAG_047', 'ja_AAC_AAG_055',
'ja_AAC_AAG_076', 'ja_AAC_AAG_099', 'ja_AAC_AAG_128',
'ja_AAC_AAG_132', 'ja_AAC_AAG_134', 'ja_AAC_AAG_137',
'ja_AAC_SPD_013', 'ja_AAC_SPD_034', 'ja_AAC_SPD_050',
'ja_AAC_SPD_060', 'ja_AAC_SPD_078', 'ja_AAC_SPD_106',
'ja_DES_AAG_079', 'ja_DES_AAG_081', 'ja_DES_AAG_113',
'ja_MOB_AAG_065', 'ja_MOB_AAG_073', 'ja_MOB_AAG_092',
'ja_MOB_AAG_127', 'ja_MOB_AAG_129', 'ja_MOB_AAG_144',
'ru_AAC_AAG_008', 'ru_AAC_AAG_145', 'ru_AAC_AAG_146',
'ru_AAC_SPD_000', 'ru_AAC_SPD_090', 'ru_AAC_SPD_148',
'ru_DES_AAG_003', 'ru_DES_AAG_007', 'ru_DES_AAG_017',
'ru_DES_AAG_041', 'ru_DES_AAG_071', 'ru_DES_AAG_072',
'ru_MOB_AAG_002', 'ru_MOB_AAG_040', 'ru_MOB_AAG_083',
'ru_MOB_AAG_086', 'ru_MOB_AAG_103', 'ru_MOB_AAG_107',
'ru_MOB_AAG_118', 'ru_MOB_AAG_125', 'zh_AAC_AAG_021',
'zh_AAC_AAG_033', 'zh_AAC_AAG_037', 'zh_AAC_AAG_052',
'zh_AAC_AAG_057', 'zh_AAC_AAG_085', 'zh_AAC_AAG_108',
'zh_AAC_SPD_039', 'zh_AAC_SPD_096', 'zh_DES_AAG_009',
'zh_DES_AAG_019', 'zh_DES_AAG_053', 'zh_DES_AAG_054',
'zh_DES_AAG_056', 'zh_DES_AAG_068', 'zh_DES_AAG_089',
'zh_DES_AAG_139', 'zh_MOB_AAG_005', 'zh_MOB_AAG_028',
'zh_MOB_AAG_031', 'zh_MOB_AAG_036', 'zh_MOB_AAG_138'], dtype=object)}
我们将数据框分为训练/测试集。
# Hold out the last 7 rows of every series as the test set; everything
# else is used for training.
Y_test_df = Y_df.groupby('unique_id').tail(7)
Y_train_df = Y_df.drop(Y_test_df.index)

# Index both frames by series identifier.
Y_test_df = Y_test_df.set_index('unique_id')
Y_train_df = Y_train_df.set_index('unique_id')
2. 基础预测
下面的单元格使用 ETS
和 naive
模型计算每个时间序列的基础预测。请注意,`Y_hat_df` 包含基础预测值,但这些预测在各层级之间并不一致(不满足聚合约束)。
%%capture
from statsforecast.models import ETS, Naive
from statsforecast.core import StatsForecast
%%capture
# Fit ETS (weekly season length) and Naive on every training series.
fcst = StatsForecast(
    df=Y_train_df,
    models=[ETS(season_length=7, model='ZAA'), Naive()],
    freq='D',
    n_jobs=-1
)
# Base forecasts for a 7-step horizon, matching the 7 held-out rows.
Y_hat_df = fcst.forecast(h=7)
注意到ETS模型对一些序列计算出负预测值。
# Rows where the ETS base forecast is negative.
Y_hat_df.query('ETS < 0')
ds | ETS | Naive | |
---|---|---|---|
unique_id | |||
de_AAC_AAG_001 | 2016-12-25 | -487.601532 | 340.0 |
de_AAC_AAG_001 | 2016-12-26 | -215.634201 | 340.0 |
de_AAC_AAG_001 | 2016-12-27 | -173.175613 | 340.0 |
de_AAC_AAG_001 | 2016-12-30 | -290.836060 | 340.0 |
de_AAC_AAG_001 | 2016-12-31 | -784.441040 | 340.0 |
... | ... | ... | ... |
zh_AAC_AAG_033 | 2016-12-31 | -86.526421 | 37.0 |
zh_MOB | 2016-12-26 | -199.534882 | 1036.0 |
zh_MOB | 2016-12-27 | -69.527260 | 1036.0 |
zh_MOB_AAG | 2016-12-26 | -199.534882 | 1036.0 |
zh_MOB_AAG | 2016-12-27 | -69.527260 | 1036.0 |
99 rows × 3 columns
3. 非负协调
以下单元格使用HierarchicalReconciliation
类使之前的预测结果保持一致且非负。
from hierarchicalforecast.methods import MinTrace
from hierarchicalforecast.core import HierarchicalReconciliation
%%capture
# Two OLS MinTrace reconcilers: one unconstrained and one with the
# non-negativity option enabled.
reconcilers = [
    MinTrace(method='ols'),
    MinTrace(method='ols', nonnegative=True)
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
# Reconcile the base forecasts using the summing frame and level tags.
Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,
                          S=S_df, tags=tags)
请注意,非负协调方法得到的预测均为非负。
# Rows where the non-negative reconciled ETS forecast is below zero.
Y_rec_df.query('`ETS/MinTrace_method-ols_nonnegative-True` < 0')
ds | ETS | Naive | ETS/MinTrace_method-ols | Naive/MinTrace_method-ols | ETS/MinTrace_method-ols_nonnegative-True | Naive/MinTrace_method-ols_nonnegative-True | |
---|---|---|---|---|---|---|---|
unique_id |
不带非负约束的协调方法会产生负预测。
# Rows where the unconstrained reconciled ETS forecast is below zero.
Y_rec_df.query('`ETS/MinTrace_method-ols` < 0')
ds | ETS | Naive | ETS/MinTrace_method-ols | Naive/MinTrace_method-ols | ETS/MinTrace_method-ols_nonnegative-True | Naive/MinTrace_method-ols_nonnegative-True | |
---|---|---|---|---|---|---|---|
unique_id | |||||||
de_DES | 2016-12-25 | -2553.932861 | 495.0 | -3468.745214 | 495.0 | 2.262540e-15 | 495.0 |
de_DES | 2016-12-26 | -2155.228271 | 495.0 | -2985.587125 | 495.0 | 1.356705e-30 | 495.0 |
de_DES | 2016-12-27 | -2720.993896 | 495.0 | -3698.680055 | 495.0 | 6.857413e-30 | 495.0 |
de_DES | 2016-12-29 | -3429.432617 | 495.0 | -2965.207609 | 495.0 | 2.456449e+02 | 495.0 |
de_DES | 2016-12-30 | -3963.202637 | 495.0 | -3217.360371 | 495.0 | 3.646790e+02 | 495.0 |
... | ... | ... | ... | ... | ... | ... | ... |
zh_MOB_AAG_036 | 2016-12-26 | 75.298317 | 115.0 | -165.799776 | 115.0 | 3.207772e-14 | 115.0 |
zh_MOB_AAG_036 | 2016-12-27 | 72.895554 | 115.0 | -134.340626 | 115.0 | 2.308198e-14 | 115.0 |
zh_MOB_AAG_138 | 2016-12-25 | 94.796623 | 65.0 | -47.009813 | 65.0 | 3.116938e-14 | 65.0 |
zh_MOB_AAG_138 | 2016-12-26 | 71.293983 | 65.0 | -169.804110 | 65.0 | 0.000000e+00 | 65.0 |
zh_MOB_AAG_138 | 2016-12-27 | 62.049744 | 65.0 | -145.186436 | 65.0 | 0.000000e+00 | 65.0 |
240 rows × 7 columns
4. 评估
HierarchicalForecast
包含 HierarchicalEvaluation
类,用于评估不同的层次结构,并能够计算相对于基准模型的缩放指标。
from hierarchicalforecast.evaluation import HierarchicalEvaluation
def mse(y, y_hat):
    """Return the mean squared error between ``y`` and ``y_hat``.

    The squared differences are averaged over every element, so the
    result is a single scalar regardless of the inputs' dimensionality.
    """
    return np.mean((y-y_hat)**2)
# Evaluate the forecasts per hierarchy level with `mse`; passing
# benchmark='Naive' reports metrics scaled by the Naive model.
evaluator = HierarchicalEvaluation(evaluators=[mse])
evaluation = evaluator.evaluate(
    Y_hat_df=Y_rec_df, Y_test_df=Y_test_df,
    tags=tags, benchmark='Naive'
)
# Keep only the ETS-based columns and transpose for display.
evaluation.filter(like='ETS', axis=1).T
level | Overall | Views | Views/Country | Views/Country/Access | Views/Country/Access/Agent | Views/Country/Access/Agent/Topic |
---|---|---|---|---|---|---|
metric | mse-scaled | mse-scaled | mse-scaled | mse-scaled | mse-scaled | mse-scaled |
ETS | 1.011585 | 0.7358 | 1.190354 | 1.103657 | 1.089515 | 1.397139 |
ETS/MinTrace_method-ols | 0.979163 | 0.698355 | 1.062521 | 1.143277 | 1.113349 | 1.354041 |
ETS/MinTrace_method-ols_nonnegative-True | 0.945075 | 0.677892 | 1.004639 | 1.184719 | 1.141442 | 1.158672 |
请注意,就整体(Overall)指标而言,非负协调方法的表现优于其无约束的对应方法。
参考文献
- Hyndman, R.J. 和 Athanasopoulos, G. (2021). “预测:原理与实践,第3版: 第11章:预测分层和分组序列。” OTexts:澳大利亚墨尔本。OTexts.com/fpp3 访问于2022年7月。
- Wickramasuriya, S.L., Athanasopoulos, G. 和 Hyndman, R.J. (2019). “通过迹最小化进行分层和分组时间序列的最佳预测协调”。美国统计协会杂志, 114, 804–819。doi:10.1080/01621459.2018.1448825。
- Wickramasuriya, S.L., Turlach, B.A. 和 Hyndman, R.J. (2020). “最佳非负预测协调”。统计计算 30, 1167–1182。https://doi.org/10.1007/s11222-020-09930-0
If you find the code useful, please ⭐ us on Github