Non-Negative MinTrace

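Large collections of time series are often organized into different aggregation levels. Their forecasts usually have to satisfy the aggregation constraints and be non-negative, which creates the challenge of designing algorithms that produce coherent, non-negative forecasts.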

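The HierarchicalForecast package provides a wide collection of Python implementations of hierarchical forecasting algorithms, including methods that perform non-negative hierarchical reconciliation.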

In this notebook, we show how to use the HierarchicalForecast package to perform non-negative reconciliation of forecasts on the Wiki2 dataset.

You can run these experiments on CPU or GPU in Google Colab.


%%capture
!pip install hierarchicalforecast statsforecast datasetsforecast

1. Load Data

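In this example we use the Wiki2 dataset. The following cell loads the time series for the different levels of the hierarchy, the summing dataframe S_df, which recovers the full dataset from the bottom-level series, and the indices of each level, called tags.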

import numpy as np
import pandas as pd

from datasetsforecast.hierarchical import HierarchicalData
Y_df, S_df, tags = HierarchicalData.load('./data', 'Wiki2')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
Y_df.head()
unique_id ds y
0 Total 2016-01-01 156508
1 Total 2016-01-02 129902
2 Total 2016-01-03 138203
3 Total 2016-01-04 115017
4 Total 2016-01-05 126042
S_df.iloc[:5, :5]
de_AAC_AAG_001 de_AAC_AAG_010 de_AAC_AAG_014 de_AAC_AAG_045 de_AAC_AAG_063
Total 1 1 1 1 1
de 1 1 1 1 1
en 0 0 0 0 0
fr 0 0 0 0 0
ja 0 0 0 0 0
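To make the role of S_df concrete, here is a minimal sketch on a hypothetical toy hierarchy (the series names Total, A, B, A1, A2, B1 and all numbers are invented for illustration). Each row of the summing matrix marks which bottom-level series add up to that row's series, so multiplying the matrix by the bottom-level values recovers every level.

```python
import numpy as np
import pandas as pd

# Hypothetical toy hierarchy: Total -> {A, B}; A -> {A1, A2}; B -> {B1}
bottom = ["A1", "A2", "B1"]
S_toy = pd.DataFrame(
    [[1, 1, 1],   # Total = A1 + A2 + B1
     [1, 1, 0],   # A = A1 + A2
     [0, 0, 1],   # B = B1
     [1, 0, 0],   # A1 (bottom level: identity rows)
     [0, 1, 0],   # A2
     [0, 0, 1]],  # B1
    index=["Total", "A", "B", "A1", "A2", "B1"],
    columns=bottom,
)

# Made-up bottom-level observations for one time step
b = np.array([10.0, 5.0, 3.0])

# y = S b recovers every level of the hierarchy from the bottom level
y_all = pd.Series(S_toy.to_numpy() @ b, index=S_toy.index)
print(y_all)  # Total 18, A 15, B 3, A1 10, A2 5, B1 3
```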
tags
{'Views': array(['Total'], dtype=object),
 'Views/Country': array(['de', 'en', 'fr', 'ja', 'ru', 'zh'], dtype=object),
 'Views/Country/Access': array(['de_AAC', 'de_DES', 'de_MOB', 'en_AAC', 'en_DES', 'en_MOB',
        'fr_AAC', 'fr_DES', 'fr_MOB', 'ja_AAC', 'ja_DES', 'ja_MOB',
        'ru_AAC', 'ru_DES', 'ru_MOB', 'zh_AAC', 'zh_DES', 'zh_MOB'],
       dtype=object),
 'Views/Country/Access/Agent': array(['de_AAC_AAG', 'de_AAC_SPD', 'de_DES_AAG', 'de_MOB_AAG',
        'en_AAC_AAG', 'en_AAC_SPD', 'en_DES_AAG', 'en_MOB_AAG',
        'fr_AAC_AAG', 'fr_AAC_SPD', 'fr_DES_AAG', 'fr_MOB_AAG',
        'ja_AAC_AAG', 'ja_AAC_SPD', 'ja_DES_AAG', 'ja_MOB_AAG',
        'ru_AAC_AAG', 'ru_AAC_SPD', 'ru_DES_AAG', 'ru_MOB_AAG',
        'zh_AAC_AAG', 'zh_AAC_SPD', 'zh_DES_AAG', 'zh_MOB_AAG'],
       dtype=object),
 'Views/Country/Access/Agent/Topic': array(['de_AAC_AAG_001', 'de_AAC_AAG_010', 'de_AAC_AAG_014',
        'de_AAC_AAG_045', 'de_AAC_AAG_063', 'de_AAC_AAG_100',
        'de_AAC_AAG_110', 'de_AAC_AAG_123', 'de_AAC_AAG_143',
        'de_AAC_SPD_012', 'de_AAC_SPD_074', 'de_AAC_SPD_080',
        'de_AAC_SPD_105', 'de_AAC_SPD_115', 'de_AAC_SPD_133',
        'de_DES_AAG_064', 'de_DES_AAG_116', 'de_DES_AAG_131',
        'de_MOB_AAG_015', 'de_MOB_AAG_020', 'de_MOB_AAG_032',
        'de_MOB_AAG_059', 'de_MOB_AAG_062', 'de_MOB_AAG_088',
        'de_MOB_AAG_095', 'de_MOB_AAG_109', 'de_MOB_AAG_122',
        'de_MOB_AAG_149', 'en_AAC_AAG_044', 'en_AAC_AAG_049',
        'en_AAC_AAG_075', 'en_AAC_AAG_114', 'en_AAC_AAG_119',
        'en_AAC_AAG_141', 'en_AAC_SPD_004', 'en_AAC_SPD_011',
        'en_AAC_SPD_026', 'en_AAC_SPD_048', 'en_AAC_SPD_067',
        'en_AAC_SPD_126', 'en_AAC_SPD_140', 'en_DES_AAG_016',
        'en_DES_AAG_024', 'en_DES_AAG_042', 'en_DES_AAG_069',
        'en_DES_AAG_082', 'en_DES_AAG_102', 'en_MOB_AAG_018',
        'en_MOB_AAG_022', 'en_MOB_AAG_101', 'en_MOB_AAG_124',
        'fr_AAC_AAG_029', 'fr_AAC_AAG_046', 'fr_AAC_AAG_070',
        'fr_AAC_AAG_087', 'fr_AAC_AAG_098', 'fr_AAC_AAG_104',
        'fr_AAC_AAG_111', 'fr_AAC_AAG_112', 'fr_AAC_AAG_142',
        'fr_AAC_SPD_025', 'fr_AAC_SPD_027', 'fr_AAC_SPD_035',
        'fr_AAC_SPD_077', 'fr_AAC_SPD_084', 'fr_AAC_SPD_097',
        'fr_AAC_SPD_130', 'fr_DES_AAG_023', 'fr_DES_AAG_043',
        'fr_DES_AAG_051', 'fr_DES_AAG_058', 'fr_DES_AAG_061',
        'fr_DES_AAG_091', 'fr_DES_AAG_093', 'fr_DES_AAG_094',
        'fr_DES_AAG_136', 'fr_MOB_AAG_006', 'fr_MOB_AAG_030',
        'fr_MOB_AAG_066', 'fr_MOB_AAG_117', 'fr_MOB_AAG_120',
        'fr_MOB_AAG_121', 'fr_MOB_AAG_135', 'fr_MOB_AAG_147',
        'ja_AAC_AAG_038', 'ja_AAC_AAG_047', 'ja_AAC_AAG_055',
        'ja_AAC_AAG_076', 'ja_AAC_AAG_099', 'ja_AAC_AAG_128',
        'ja_AAC_AAG_132', 'ja_AAC_AAG_134', 'ja_AAC_AAG_137',
        'ja_AAC_SPD_013', 'ja_AAC_SPD_034', 'ja_AAC_SPD_050',
        'ja_AAC_SPD_060', 'ja_AAC_SPD_078', 'ja_AAC_SPD_106',
        'ja_DES_AAG_079', 'ja_DES_AAG_081', 'ja_DES_AAG_113',
        'ja_MOB_AAG_065', 'ja_MOB_AAG_073', 'ja_MOB_AAG_092',
        'ja_MOB_AAG_127', 'ja_MOB_AAG_129', 'ja_MOB_AAG_144',
        'ru_AAC_AAG_008', 'ru_AAC_AAG_145', 'ru_AAC_AAG_146',
        'ru_AAC_SPD_000', 'ru_AAC_SPD_090', 'ru_AAC_SPD_148',
        'ru_DES_AAG_003', 'ru_DES_AAG_007', 'ru_DES_AAG_017',
        'ru_DES_AAG_041', 'ru_DES_AAG_071', 'ru_DES_AAG_072',
        'ru_MOB_AAG_002', 'ru_MOB_AAG_040', 'ru_MOB_AAG_083',
        'ru_MOB_AAG_086', 'ru_MOB_AAG_103', 'ru_MOB_AAG_107',
        'ru_MOB_AAG_118', 'ru_MOB_AAG_125', 'zh_AAC_AAG_021',
        'zh_AAC_AAG_033', 'zh_AAC_AAG_037', 'zh_AAC_AAG_052',
        'zh_AAC_AAG_057', 'zh_AAC_AAG_085', 'zh_AAC_AAG_108',
        'zh_AAC_SPD_039', 'zh_AAC_SPD_096', 'zh_DES_AAG_009',
        'zh_DES_AAG_019', 'zh_DES_AAG_053', 'zh_DES_AAG_054',
        'zh_DES_AAG_056', 'zh_DES_AAG_068', 'zh_DES_AAG_089',
        'zh_DES_AAG_139', 'zh_MOB_AAG_005', 'zh_MOB_AAG_028',
        'zh_MOB_AAG_031', 'zh_MOB_AAG_036', 'zh_MOB_AAG_138'], dtype=object)}
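In a strict hierarchy, each bottom-level series belongs to exactly one node per level, so the rows of the summing matrix selected by any single entry of tags should add up to a vector of ones. The sketch below checks this on a hypothetical toy S and tags-style dictionary (all names invented); running the same loop over S_df and tags is a quick sanity check for Wiki2, assuming its hierarchy is strict rather than grouped.

```python
import numpy as np
import pandas as pd

# Hypothetical toy summing matrix (rows: all series, columns: bottom series)
S_toy = pd.DataFrame(
    [[1, 1, 1], [1, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
    index=["Total", "A", "B", "A1", "A2", "B1"],
    columns=["A1", "A2", "B1"],
)

# tags-style dictionary: level name -> array of unique_ids at that level
tags_toy = {
    "Total": np.array(["Total"], dtype=object),
    "Total/Group": np.array(["A", "B"], dtype=object),
    "Total/Group/Item": np.array(["A1", "A2", "B1"], dtype=object),
}

# Each level's rows of S should cover every bottom series exactly once
for level, ids in tags_toy.items():
    col_sums = S_toy.loc[ids].sum(axis=0).to_numpy()
    print(level, col_sums)
    assert np.all(col_sums == 1), f"level {level} is not a partition"
```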

We split the dataframe into training and test sets, holding out the last 7 days of each series for testing.

Y_test_df = Y_df.groupby('unique_id').tail(7)
Y_train_df = Y_df.drop(Y_test_df.index)
Y_test_df = Y_test_df.set_index('unique_id')
Y_train_df = Y_train_df.set_index('unique_id')
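The split relies on groupby(...).tail(h) returning the last h rows of each series, which assumes the rows are sorted by ds within each unique_id. A minimal sketch on a hypothetical two-series panel with invented values:

```python
import pandas as pd

# Hypothetical toy panel: two series with 10 daily observations each
toy = pd.DataFrame({
    "unique_id": ["a"] * 10 + ["b"] * 10,
    "ds": list(pd.date_range("2016-01-01", periods=10, freq="D")) * 2,
    "y": range(20),
})

h = 7  # forecast horizon, matching the 7-day test window above
toy_test = toy.groupby("unique_id").tail(h)  # last h rows of each series
toy_train = toy.drop(toy_test.index)         # everything else

# Every series contributes exactly h test rows; training data ends before the test window
print(toy_test.groupby("unique_id").size())
print(toy_train.groupby("unique_id")["ds"].max())
```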

2. Base Forecasts

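The following cell computes the base forecasts for each time series using the ETS and Naive models. Note that Y_hat_df contains the forecasts, but they are not coherent.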

%%capture
from statsforecast.models import ETS, Naive
from statsforecast.core import StatsForecast

%%capture
fcst = StatsForecast(
    df=Y_train_df, 
    models=[ETS(season_length=7, model='ZAA'), Naive()], 
    freq='D', 
    n_jobs=-1
)
Y_hat_df = fcst.forecast(h=7)
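Coherence means that the forecast of every aggregate equals the sum of the forecasts of its bottom-level series, that is, y_hat = S @ y_hat_bottom. Because each series is forecast independently, this generally fails. The hypothetical helper is_coherent below checks the condition on a toy hierarchy with invented base forecasts; it assumes the bottom-level series occupy the last rows of S, as in the toy matrix.

```python
import numpy as np

# Toy summing matrix (hypothetical): rows [Total, A, B, A1, A2, B1], columns [A1, A2, B1]
S = np.array([[1, 1, 1], [1, 1, 0], [0, 0, 1],
              [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)

# Made-up base forecasts for [Total, A, B, A1, A2, B1], each produced independently
y_hat = np.array([20.0, 14.0, 4.0, 9.0, 6.0, 2.0])

def is_coherent(y, S, atol=1e-8):
    """Forecasts are coherent if they equal S times their own bottom-level part."""
    n_bottom = S.shape[1]
    bottom = y[-n_bottom:]  # assumes bottom series are the last rows of S
    return np.allclose(y, S @ bottom, atol=atol)

print(is_coherent(y_hat, S))           # False: 9 + 6 + 2 = 17, not 20
print(is_coherent(S @ y_hat[-3:], S))  # True: bottom-up sums are coherent by construction
```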

Note that the ETS model produces negative forecasts for some series.

Y_hat_df.query('ETS < 0')
ds ETS Naive
unique_id
de_AAC_AAG_001 2016-12-25 -487.601532 340.0
de_AAC_AAG_001 2016-12-26 -215.634201 340.0
de_AAC_AAG_001 2016-12-27 -173.175613 340.0
de_AAC_AAG_001 2016-12-30 -290.836060 340.0
de_AAC_AAG_001 2016-12-31 -784.441040 340.0
... ... ... ...
zh_AAC_AAG_033 2016-12-31 -86.526421 37.0
zh_MOB 2016-12-26 -199.534882 1036.0
zh_MOB 2016-12-27 -69.527260 1036.0
zh_MOB_AAG 2016-12-26 -199.534882 1036.0
zh_MOB_AAG 2016-12-27 -69.527260 1036.0

99 rows × 3 columns

3. Non-Negative Reconciliation

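The following cell uses the HierarchicalReconciliation class to make the previous forecasts coherent with two MinTrace reconcilers: the unconstrained OLS variant, and the same variant with nonnegative=True, which additionally forces the reconciled forecasts to be non-negative.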

from hierarchicalforecast.methods import MinTrace
from hierarchicalforecast.core import HierarchicalReconciliation

%%capture
reconcilers = [
    MinTrace(method='ols'),                   # unconstrained OLS reconciliation
    MinTrace(method='ols', nonnegative=True)  # OLS reconciliation with non-negative forecasts
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df, Y_df=Y_train_df,
                          S=S_df, tags=tags)
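With method='ols', MinTrace uses the identity matrix as the error covariance W, and the reconciled forecasts reduce to the OLS projection `y_tilde = S (S^T S)^{-1} S^T y_hat`. The sketch below applies that formula with NumPy to a hypothetical toy hierarchy with invented base forecasts; it illustrates the formula and is not HierarchicalForecast's implementation. The base forecasts contain negative values, and the projection keeps B and B1 negative, mirroring the negative unconstrained forecasts observed for Wiki2.

```python
import numpy as np

# Toy summing matrix (hypothetical): rows [Total, A, B, A1, A2, B1], columns [A1, A2, B1]
S = np.array([[1, 1, 1], [1, 1, 0], [0, 0, 1],
              [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)

# Made-up incoherent base forecasts; B and B1 are negative
y_hat = np.array([10.0, 12.0, -1.0, 6.0, 5.0, -2.0])

# MinTrace with W = I (OLS): P = (S^T S)^{-1} S^T, reconciled forecasts y_tilde = S P y_hat
P = np.linalg.solve(S.T @ S, S.T)
y_tilde = S @ P @ y_hat

print(np.round(y_tilde, 3))
print(np.allclose(y_tilde, S @ y_tilde[-3:]))  # coherent by construction
```

Here the reconciled B1 forecast is -20/13, so coherence alone does not guarantee non-negativity.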

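Note that the non-negative reconciliation method produces only non-negative forecasts: the query below returns no rows.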

Y_rec_df.query('`ETS/MinTrace_method-ols_nonnegative-True` < 0')
ds ETS Naive ETS/MinTrace_method-ols Naive/MinTrace_method-ols ETS/MinTrace_method-ols_nonnegative-True Naive/MinTrace_method-ols_nonnegative-True
unique_id
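For the OLS case (W = I), requiring non-negative forecasts can be posed as minimizing `||y_hat - S b||^2` subject to `b >= 0`, which is a non-negative least-squares problem. Because S contains only 0/1 entries, non-negative bottom-level forecasts make every aggregate non-negative too. The sketch below swaps in SciPy's scipy.optimize.nnls as the solver on a hypothetical toy hierarchy with invented forecasts; HierarchicalForecast may solve the problem differently (for example with a general quadratic-programming solver), so treat this as an illustration of the constraint rather than the library's algorithm.

```python
import numpy as np
from scipy.optimize import nnls

# Toy summing matrix (hypothetical): rows [Total, A, B, A1, A2, B1], columns [A1, A2, B1]
S = np.array([[1, 1, 1], [1, 1, 0], [0, 0, 1],
              [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)

# Made-up incoherent base forecasts with negative values for B and B1
y_hat = np.array([10.0, 12.0, -1.0, 6.0, 5.0, -2.0])

# With W = I, non-negative reconciliation is
#   min_b ||y_hat - S b||^2  subject to  b >= 0,
# i.e. non-negative least squares (solver choice is an assumption of this sketch)
b_nonneg, _residual_norm = nnls(S, y_hat)
y_tilde_nonneg = S @ b_nonneg  # coherent and non-negative at every level

print(np.round(b_nonneg, 3))        # approximately [6, 5, 0]
print(np.round(y_tilde_nonneg, 3))  # approximately [11, 11, 0, 6, 5, 0]
```

The unconstrained OLS solution for this toy problem has B1 = -20/13; the constraint clips that series to zero and re-balances A1 and A2 accordingly.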

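The unconstrained reconciliation method, by contrast, produces negative forecasts. In the corresponding rows, the non-negative method returns values such as 1.356705e-30, which are zero up to numerical precision.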

Y_rec_df.query('`ETS/MinTrace_method-ols` < 0')
ds ETS Naive ETS/MinTrace_method-ols Naive/MinTrace_method-ols ETS/MinTrace_method-ols_nonnegative-True Naive/MinTrace_method-ols_nonnegative-True
unique_id
de_DES 2016-12-25 -2553.932861 495.0 -3468.745214 495.0 2.262540e-15 495.0
de_DES 2016-12-26 -2155.228271 495.0 -2985.587125 495.0 1.356705e-30 495.0
de_DES 2016-12-27 -2720.993896 495.0 -3698.680055 495.0 6.857413e-30 495.0
de_DES 2016-12-29 -3429.432617 495.0 -2965.207609 495.0 2.456449e+02 495.0
de_DES 2016-12-30 -3963.202637 495.0 -3217.360371 495.0 3.646790e+02 495.0
... ... ... ... ... ... ... ...
zh_MOB_AAG_036 2016-12-26 75.298317 115.0 -165.799776 115.0 3.207772e-14 115.0
zh_MOB_AAG_036 2016-12-27 72.895554 115.0 -134.340626 115.0 2.308198e-14 115.0
zh_MOB_AAG_138 2016-12-25 94.796623 65.0 -47.009813 65.0 3.116938e-14 65.0
zh_MOB_AAG_138 2016-12-26 71.293983 65.0 -169.804110 65.0 0.000000e+00 65.0
zh_MOB_AAG_138 2016-12-27 62.049744 65.0 -145.186436 65.0 0.000000e+00 65.0

240 rows × 7 columns

4. Evaluation

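HierarchicalForecast includes the HierarchicalEvaluation class, which evaluates the forecasts at every level of the hierarchy and can compute metrics scaled relative to a benchmark model.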
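Assuming a scaled metric is the model's error divided by the benchmark's error on the same data (so the benchmark itself scores 1 and values below 1 beat it), the computation for a single series looks like this, with invented numbers:

```python
import numpy as np

def mse(y, y_hat):
    return np.mean((y - y_hat) ** 2)

# Made-up test values and forecasts for a single series
y_true = np.array([100.0, 120.0, 110.0])
model_fcst = np.array([105.0, 115.0, 110.0])
naive_fcst = np.array([90.0, 90.0, 90.0])

# Assumed definition of the scaled metric: model error / benchmark error
mse_scaled = mse(y_true, model_fcst) / mse(y_true, naive_fcst)
print(round(mse_scaled, 4))  # 0.0357, i.e. (50/3) / (1400/3)
```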

from hierarchicalforecast.evaluation import HierarchicalEvaluation
def mse(y, y_hat):
    return np.mean((y-y_hat)**2)

evaluator = HierarchicalEvaluation(evaluators=[mse])
evaluation = evaluator.evaluate(
        Y_hat_df=Y_rec_df, Y_test_df=Y_test_df, 
        tags=tags, benchmark='Naive'
)
evaluation.filter(like='ETS', axis=1).T
level Overall Views Views/Country Views/Country/Access Views/Country/Access/Agent Views/Country/Access/Agent/Topic
metric mse-scaled mse-scaled mse-scaled mse-scaled mse-scaled mse-scaled
ETS 1.011585 0.7358 1.190354 1.103657 1.089515 1.397139
ETS/MinTrace_method-ols 0.979163 0.698355 1.062521 1.143277 1.113349 1.354041
ETS/MinTrace_method-ols_nonnegative-True 0.945075 0.677892 1.004639 1.184719 1.141442 1.158672

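Note that the non-negative reconciliation method outperforms its unconstrained counterpart overall (0.945 vs. 0.979 scaled MSE) and at the Views, Views/Country, and Topic levels, although it is slightly worse at the Views/Country/Access and Views/Country/Access/Agent levels.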
