# %pip install datasetsforecast hierarchicalforecast mlforecast neuralforecast
Neural Network / Machine Learning Forecasting
This example notebook showcases the compatibility of HierarchicalForecast's reconciliation methods with popular machine-learning libraries, specifically NeuralForecast and MLForecast.
The notebook uses the NBEATS and XGBRegressor models to create base forecasts for the TourismLarge hierarchical dataset. After that, we use HierarchicalForecast to reconcile the base predictions.
References
- Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). "N-BEATS: Neural basis expansion analysis for interpretable time series forecasting". URL: https://arxiv.org/abs/1905.10437
- Tianqi Chen and Carlos Guestrin. "XGBoost: A Scalable Tree Boosting System". In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. KDD '16. San Francisco, California, USA: Association for Computing Machinery, 2016, pp. 785-794. ISBN: 9781450342322. DOI: 10.1145/2939672.2939785. URL: https://doi.org/10.1145/2939672.2939785
You can run these experiments using CPU or GPU with Google Colab.
1. Installing Packages
import numpy as np
import pandas as pd
from datasetsforecast.hierarchical import HierarchicalData
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS
from neuralforecast.losses.pytorch import GMM
from mlforecast import MLForecast
from mlforecast.utils import PredictionIntervals
import xgboost as xgb
# Hierarchical reconciliation methods and evaluation
from hierarchicalforecast.methods import BottomUp, ERM, MinTrace
from hierarchicalforecast.utils import HierarchicalPlot
from hierarchicalforecast.core import HierarchicalReconciliation
from hierarchicalforecast.evaluation import scaled_crps
2. Load Hierarchical Dataset
This detailed Australian Tourism dataset comes from the National Visitor Survey, managed by Tourism Research Australia. It comprises 555 monthly series from 1998 to 2016, organized geographically and by purpose of travel. The natural geographical hierarchy consists of seven states, further divided into 27 zones and 76 regions. The purpose-of-travel categories are holiday, visiting friends and relatives (VFR), business, and other. MinT (Wickramasuriya et al., 2019) and other hierarchical forecasting studies have used this dataset in the past. It can be accessed from the MinT reconciliation webpage, although other sources are also available.
Geographical division | Number of series per division | Number of series per purpose | Total
---|---|---|---
Australia | 1 | 4 | 5
States | 7 | 28 | 35
Zones | 27 | 108 | 135
Regions | 76 | 304 | 380
Total | 111 | 444 | 555
Y_df, S_df, tags = HierarchicalData.load('./data', 'TourismLarge')
Y_df['ds'] = pd.to_datetime(Y_df['ds'])
Y_df.head()
| unique_id | ds | y
---|---|---|---
0 | TotalAll | 1998-01-01 | 45151.071280 |
1 | TotalAll | 1998-02-01 | 17294.699551 |
2 | TotalAll | 1998-03-01 | 20725.114184 |
3 | TotalAll | 1998-04-01 | 25388.612353 |
4 | TotalAll | 1998-05-01 | 20330.035211 |
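As a quick sanity check (a small sketch, using only the tags and S_df objects returned by the loader), the series counts per hierarchy level should add up to the totals in the table above:

# Count the series in each hierarchy level; the sum should equal len(S_df) = 555.
for level_name, level_series in tags.items():
    print(f'{level_name}: {len(level_series)} series')
print(f'Total: {len(S_df)} series')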
Visualize the aggregation matrix.
hplot = HierarchicalPlot(S=S_df, tags=tags)
hplot.plot_summing_matrix()
Split the dataframe into train/test sets.
def sort_hier_df(Y_df, S_df):
    # Sort unique_id lexicographically, following the order of the summing matrix
    Y_df.unique_id = Y_df.unique_id.astype('category')
    Y_df.unique_id = Y_df.unique_id.cat.set_categories(S_df.index)
    Y_df = Y_df.sort_values(by=['unique_id', 'ds'])
    return Y_df

Y_df = sort_hier_df(Y_df, S_df)
horizon = 12
Y_test_df = Y_df.groupby('unique_id').tail(horizon)
Y_train_df = Y_df.drop(Y_test_df.index)
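To confirm the split behaves as intended (a minimal check), every series should contribute exactly horizon observations to the test set:

# Each unique_id keeps its last 12 monthly observations for testing.
assert (Y_test_df.groupby('unique_id').size() == horizon).all()
print(Y_train_df.shape, Y_test_df.shape)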
3. Fit and Predict Models
HierarchicalForecast is compatible with many different machine-learning models. Here, we show two examples:
1. NBEATS, a deep neural architecture based on multilayer perceptrons.
2. XGBRegressor, a tree-based architecture.
level = np.arange(0, 100, 2)
qs = [[50 - lv / 2, 50 + lv / 2] for lv in level]
quantiles = np.sort(np.concatenate(qs) / 100)
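For intuition, this turns the 50 confidence levels 0, 2, ..., 98 into 100 quantiles placed symmetrically around the median (a small worked illustration):

# level 0 yields the pair [50, 50] (the median twice); level 98 yields [1, 99].
# After concatenating, dividing by 100, and sorting:
print(len(quantiles))                 # 100
print(quantiles[:3], quantiles[-3:])  # [0.01 0.02 0.03] [0.97 0.98 0.99]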
# Fit/predict NBEATS from NeuralForecast
nbeats = NBEATS(h=horizon,
                input_size=2 * horizon,
                loss=GMM(n_components=10, quantiles=quantiles),
                scaler_type='robust',
                max_steps=2000)
nf = NeuralForecast(models=[nbeats], freq='MS')
nf.fit(df=Y_train_df)
Y_hat_nf = nf.predict()
insample_nf = nf.predict_insample(step_size=horizon)
# Fit/predict XGBRegressor from MLForecast
mf = MLForecast(models=[xgb.XGBRegressor()],
                freq='MS',
                lags=[1, 2, 12, 24],
                date_features=['month'],
                )
mf.fit(Y_train_df, fitted=True, prediction_intervals=PredictionIntervals(n_windows=10, h=horizon))
Y_hat_mf = mf.predict(horizon, level=level).set_index('unique_id')
insample_mf = mf.forecast_fitted_values()
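Both frameworks return a point-forecast column plus one lo/hi column pair per confidence level; for MLForecast these intervals come from the PredictionIntervals configuration above. A quick inspection sketch:

# 50 levels x 2 bounds = 100 interval columns per model.
interval_cols_nf = [c for c in Y_hat_nf.columns if '-lo-' in c or '-hi-' in c]
interval_cols_mf = [c for c in Y_hat_mf.columns if '-lo-' in c or '-hi-' in c]
print(len(interval_cols_nf), len(interval_cols_mf))  # 100 100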
Y_hat_nf
unique_id | ds | NBEATS | NBEATS-lo-98.0 | NBEATS-lo-96.0 | NBEATS-lo-94.0 | NBEATS-lo-92.0 | NBEATS-lo-90.0 | NBEATS-lo-88.0 | NBEATS-lo-86.0 | NBEATS-lo-84.0 | ... | NBEATS-hi-80.0 | NBEATS-hi-82.0 | NBEATS-hi-84.0 | NBEATS-hi-86.0 | NBEATS-hi-88.0 | NBEATS-hi-90.0 | NBEATS-hi-92.0 | NBEATS-hi-94.0 | NBEATS-hi-96.0 | NBEATS-hi-98.0
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
TotalAll | 2016-01-01 | 44525.652344 | 21232.554688 | 26024.839844 | 27435.285156 | 28136.705078 | 28766.150391 | 29569.240234 | 30344.240234 | 31163.099609 | ... | 51812.953125 | 52171.792969 | 52628.562500 | 52890.750000 | 53160.312500 | 54025.210938 | 54451.109375 | 55651.007812 | 57686.027344 | 61461.066406 |
TotalAll | 2016-02-01 | 20819.431641 | 18020.289062 | 18314.943359 | 18480.269531 | 18612.464844 | 18695.382812 | 18807.242188 | 18912.910156 | 19027.187500 | ... | 22719.998047 | 22802.921875 | 22887.734375 | 23031.005859 | 23133.865234 | 23230.322266 | 23406.496094 | 23622.166016 | 23887.796875 | 24165.496094 |
TotalAll | 2016-03-01 | 23676.291016 | 19303.222656 | 19684.693359 | 19928.400391 | 20150.691406 | 20319.113281 | 20499.980469 | 20632.185547 | 20748.207031 | ... | 26215.312500 | 26291.195312 | 26402.853516 | 26578.257812 | 26848.179688 | 27054.107422 | 27310.746094 | 27723.867188 | 28211.294922 | 29011.082031 |
TotalAll | 2016-04-01 | 27978.587891 | 23936.988281 | 24329.892578 | 24532.740234 | 24735.703125 | 24902.812500 | 25165.074219 | 25256.669922 | 25489.455078 | ... | 30192.365234 | 30278.451172 | 30339.017578 | 30381.443359 | 30465.722656 | 30574.056641 | 30682.609375 | 30860.427734 | 31032.648438 | 31199.992188 |
TotalAll | 2016-05-01 | 22810.310547 | 20037.218750 | 20194.531250 | 20387.541016 | 20510.244141 | 20594.226562 | 20675.720703 | 20767.025391 | 20876.550781 | ... | 24975.916016 | 25149.097656 | 25240.177734 | 25401.996094 | 25577.400391 | 25800.574219 | 26132.904297 | 26559.906250 | 27273.566406 | 28567.857422 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
GBDOth | 2016-08-01 | 3.384338 | -31.891897 | -15.230768 | -1.954657 | -1.143704 | -0.994592 | -0.947800 | -0.884839 | -0.824748 | ... | 9.635074 | 10.517044 | 11.374988 | 12.784556 | 14.568413 | 22.581669 | 37.880905 | 51.512486 | 62.645977 | 81.495415 |
GBDOth | 2016-09-01 | 4.842800 | -41.682514 | -23.578377 | -6.487054 | -1.238661 | -1.024779 | -0.927368 | -0.856639 | -0.758568 | ... | 11.743630 | 12.755230 | 14.384780 | 16.579344 | 19.425726 | 36.155537 | 44.394543 | 60.144749 | 78.533859 | 101.363129 |
GBDOth | 2016-10-01 | 4.466261 | -21.124041 | -1.662255 | -1.157058 | -0.949211 | -0.857361 | -0.755605 | -0.699540 | -0.659419 | ... | 10.405193 | 11.605769 | 12.686687 | 14.218900 | 19.963741 | 26.705273 | 34.361160 | 51.898552 | 68.361931 | 89.458908 |
GBDOth | 2016-11-01 | 3.689114 | -22.615982 | -11.813770 | -1.530864 | -1.049960 | -0.922807 | -0.868391 | -0.802971 | -0.723462 | ... | 8.213260 | 8.837670 | 10.219457 | 12.300932 | 13.135829 | 23.325760 | 37.628525 | 43.993382 | 63.594315 | 84.825226 |
GBDOth | 2016-12-01 | 3.994789 | -38.856083 | -24.361221 | -7.503808 | -1.199999 | -1.003695 | -0.880594 | -0.788414 | -0.737489 | ... | 9.881157 | 11.406334 | 12.636977 | 15.831536 | 26.059269 | 32.270000 | 37.316460 | 51.765774 | 68.933304 | 91.916100 |
6660 rows × 102 columns
Y_hat_mf
unique_id | ds | XGBRegressor | XGBRegressor-lo-98 | XGBRegressor-lo-96 | XGBRegressor-lo-94 | XGBRegressor-lo-92 | XGBRegressor-lo-90 | XGBRegressor-lo-88 | XGBRegressor-lo-86 | XGBRegressor-lo-84 | ... | XGBRegressor-hi-80 | XGBRegressor-hi-82 | XGBRegressor-hi-84 | XGBRegressor-hi-86 | XGBRegressor-hi-88 | XGBRegressor-hi-90 | XGBRegressor-hi-92 | XGBRegressor-hi-94 | XGBRegressor-hi-96 | XGBRegressor-hi-98
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
TotalAll | 2016-01-01 | 43060.226562 | 38276.974483 | 38677.670530 | 39078.366577 | 39479.062624 | 39879.758671 | 40009.218877 | 40041.809140 | 40074.399403 | ... | 45980.873195 | 46013.463459 | 46046.053722 | 46078.643985 | 46111.234248 | 46240.694454 | 46641.390501 | 47042.086548 | 47442.782595 | 47843.478642 |
TotalAll | 2016-02-01 | 18008.296875 | 14687.962868 | 14813.816467 | 14939.670066 | 15065.523666 | 15191.377265 | 15247.400539 | 15278.484410 | 15309.568281 | ... | 20644.857726 | 20675.941597 | 20707.025469 | 20738.109340 | 20769.193211 | 20825.216485 | 20951.070084 | 21076.923684 | 21202.777283 | 21328.630882 |
TotalAll | 2016-03-01 | 20694.080078 | 16407.351099 | 16594.149043 | 16780.946987 | 16967.744931 | 17154.542875 | 17209.434677 | 17217.217141 | 17224.999606 | ... | 24147.595620 | 24155.378085 | 24163.160550 | 24170.943015 | 24178.725480 | 24233.617281 | 24420.415225 | 24607.213169 | 24794.011113 | 24980.809057 |
TotalAll | 2016-04-01 | 24474.349609 | 20859.120558 | 20978.737726 | 21098.354893 | 21217.972060 | 21337.589227 | 21380.287167 | 21395.513953 | 21410.740739 | ... | 27507.504906 | 27522.731693 | 27537.958479 | 27553.185266 | 27568.412052 | 27611.109991 | 27730.727159 | 27850.344326 | 27969.961493 | 28089.578660 |
TotalAll | 2016-05-01 | 19281.087891 | 15045.235849 | 15460.108990 | 15874.982131 | 16289.855271 | 16704.728412 | 16861.927796 | 16927.100837 | 16992.273878 | ... | 21439.555822 | 21504.728863 | 21569.901904 | 21635.074945 | 21700.247986 | 21857.447369 | 22272.320510 | 22687.193651 | 23102.066792 | 23516.939933 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
GBDOth | 2016-08-01 | 11.040442 | -0.720264 | 0.934877 | 2.590017 | 4.245157 | 5.900298 | 6.396993 | 6.479957 | 6.562921 | ... | 15.352035 | 15.435000 | 15.517964 | 15.600928 | 15.683892 | 16.180587 | 17.835727 | 19.490868 | 21.146008 | 22.801149 |
GBDOth | 2016-09-01 | 6.440751 | -0.275863 | -0.182214 | -0.088566 | 0.005083 | 0.098732 | 0.123376 | 0.123376 | 0.123376 | ... | 12.758126 | 12.758126 | 12.758126 | 12.758126 | 12.758126 | 12.782771 | 12.876419 | 12.970068 | 13.063716 | 13.157365 |
GBDOth | 2016-10-01 | 9.995112 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | 2.407870 | ... | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 | 17.582355 |
GBDOth | 2016-11-01 | 6.747566 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | 2.791389 | ... | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 | 10.703742 |
GBDOth | 2016-12-01 | 7.367904 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | 2.349200 | ... | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 | 12.386609 |
6660 rows × 102 columns
4. Reconcile Predictions
With minimal parsing, we can reconcile the raw output predictions using different hierarchical forecasting reconciliation methods. Note that some reconciliation methods (e.g. ERM) make use of in-sample fitted values, which is why we pass insample_nf and insample_mf through the Y_df argument below.
reconcilers = [
    ERM(method='closed'),
    BottomUp(),
    MinTrace('ols'),
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
Y_rec_nf = hrec.reconcile(Y_hat_df=Y_hat_nf, Y_df=insample_nf, S=S_df, tags=tags, level=level)
Y_rec_mf = hrec.reconcile(Y_hat_df=Y_hat_mf, Y_df=insample_mf, S=S_df, tags=tags, level=level)
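The reconciled dataframes keep the base forecasts and add one column block per model/reconciler pair, named like NBEATS/BottomUp; their quantile columns carry the same -lo-/-hi- suffixes. A quick inspection sketch:

# List the reconciled point-forecast columns.
print([c for c in Y_rec_nf.columns if '/' in c and '-lo-' not in c and '-hi-' not in c])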
5. Evaluation
For evaluation, we use the scaled variant of the CRPS proposed by Rangapuram (2021) to measure the accuracy of the predicted quantiles y_hat relative to the observed values y.
\[ \mathrm{sCRPS}(\hat{F}_{\tau}, \mathbf{y}_{\tau}) = \frac{2}{N} \sum_{i} \int^{1}_{0} \frac{\mathrm{QL}(\hat{F}_{i,\tau}, y_{i,\tau})_{q}}{\sum_{i} | y_{i,\tau} |} dq \]
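Here \(\mathrm{QL}\) denotes the quantile (pinball) loss at level \(q\); in its standard form, with \(\hat{y}^{(q)}_{i,\tau}\) the predicted \(q\)-th quantile:

\[ \mathrm{QL}(\hat{F}_{i,\tau}, y_{i,\tau})_{q} = \left(q - \mathbb{1}\{y_{i,\tau} \leq \hat{y}^{(q)}_{i,\tau}\}\right)\left(y_{i,\tau} - \hat{y}^{(q)}_{i,\tau}\right) \]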
rec_model_names_nf = ['NBEATS/BottomUp', 'NBEATS/MinTrace_method-ols', 'NBEATS/ERM_method-closed_lambda_reg-0.01']
rec_model_names_mf = ['XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-ols', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01']

n_quantiles = len(quantiles)
n_series = len(S_df)
for name in rec_model_names_nf:
    quantile_columns = [col for col in Y_rec_nf.columns if (name+'-lo') in col or (name+'-hi') in col]
    y_rec = Y_rec_nf[quantile_columns].values
    y_test = Y_test_df['y'].values

    y_rec = y_rec.reshape(n_series, horizon, n_quantiles)
    y_test = y_test.reshape(n_series, horizon)
    scrps = scaled_crps(y=y_test, y_hat=y_rec, quantiles=quantiles)
    print("{:<50} {:.3f}".format(name+":", scrps))

for name in rec_model_names_mf:
    quantile_columns = [col for col in Y_rec_mf.columns if (name+'-lo') in col or (name+'-hi') in col]
    y_rec = Y_rec_mf[quantile_columns].values
    y_test = Y_test_df['y'].values

    y_rec = y_rec.reshape(n_series, horizon, n_quantiles)
    y_test = y_test.reshape(n_series, horizon)
    scrps = scaled_crps(y=y_test, y_hat=y_rec, quantiles=quantiles)
    print("{:<50} {:.3f}".format(name+":", scrps))
NBEATS/BottomUp: 0.129
NBEATS/MinTrace_method-ols: 0.129
NBEATS/ERM_method-closed_lambda_reg-0.01: 0.179
XGBRegressor/BottomUp: 0.134
XGBRegressor/MinTrace_method-ols: 0.178
XGBRegressor/ERM_method-closed_lambda_reg-0.01: 0.177
6. Visualization
plot_nf = pd.concat([Y_df.set_index(['unique_id', 'ds']),
                     Y_rec_nf.set_index('ds', append=True)], axis=1)
plot_nf = plot_nf.reset_index('ds')

plot_mf = pd.concat([Y_df.set_index(['unique_id', 'ds']),
                     Y_rec_mf.set_index('ds', append=True)], axis=1)
plot_mf = plot_mf.reset_index('ds')
hplot.plot_series(series='TotalVis',
                  Y_df=plot_nf,
                  models=['y', 'NBEATS', 'NBEATS/BottomUp', 'NBEATS/MinTrace_method-ols', 'NBEATS/ERM_method-closed_lambda_reg-0.01'],
                  level=[80]
                  )

hplot.plot_series(series='TotalVis',
                  Y_df=plot_mf,
                  models=['y', 'XGBRegressor', 'XGBRegressor/BottomUp', 'XGBRegressor/MinTrace_method-ols', 'XGBRegressor/ERM_method-closed_lambda_reg-0.01'],
                  level=[80]
                  )
If you find the code useful, please ⭐ us on Github