# 此单元格将不会被渲染,但用于隐藏警告并限制显示的行数。
import warnings
"ignore")
warnings.filterwarnings(
import logging
'statsforecast').setLevel(logging.ERROR)
logging.getLogger(
import pandas as pd
'display.max_rows', 6) pd.set_option(
AutoTheta 模型
使用
Statsforecast
的AutoTheta 模型
的分步指南。
在此过程中,我们将熟悉主要的 StatsForecast
类以及一些相关方法,例如 StatsForecast.plot
、StatsForecast.forecast
和 StatsForecast.cross_validation
等。
目录
引言
在建模和预测过程中,当涉及大量时间序列时,开发准确、稳健和可靠的单变量时间序列预测方法非常重要。在工业环境中,处理大型产品线是非常普遍的,因此高效的销售和运营规划(S&OP)在很大程度上依赖于准确的预测方法。
Theta 方法 (Assimakopoulos & Nikolopoulos, 2000, 以下称 A&N) 应用于非季节性或去季节化的时间序列,去季节化通常通过乘法经典分解进行。该方法将原始时间序列分解为两个新的线性,通过所谓的 theta 系数表示,记作 \({\theta}_1\) 和 \({\theta}_2\),其中 \({\theta}_1, {\theta}_2 \in \mathbb{R}\),应用于数据的二阶差分。当 \({\theta}<1\) 时,二阶差分减小,导致更好地逼近时间序列的长期行为(Assimakopoulos, 1995)。如果 \({\theta}\) 等于零,新线为直线。当 \({\theta}>1\) 时,局部曲率增大,放大时间序列的短期波动(A&N)。所产生的新线称为 theta 线,这里记为 \(\text{Z}(\theta_1)\) 和 \(\text{Z}(\theta_2)\)。这些线具有与原始数据相同的均值和斜率,但局部曲率则根据 \(\theta\) 系数的值被过滤或增强。
换句话说,分解过程的优点在于利用数据中的信息,这些信息通常无法完全通过对原始时间序列的外推捕获和建模。theta 线可以被视为新的时间序列,并使用适当的预测方法分别进行外推。一旦每条 theta 线的外推完成,就会通过组合方案进行重构,以便计算原始时间序列的点预测。组合在预测文献中长期以来被视为一种有效的实践(例如,Clemen, 1989; Makridakis 和 Winkler, 1983; Petropoulos 等, 2014),因此其在 Theta 方法中的应用预计将带来更准确和稳健的预测。
Theta 方法在选择 theta 线的数量、theta 系数和外推方法方面非常灵活,并且能够将这些组合以获得强健的预测。然而,A&N 提出了一个简化版本,仅涉及使用两个 theta 线,具有预设的 \(\theta\) 系数,通过线性回归 (LR) 模型对 \(\theta_1 = 0\) 的 theta 线进行时间外推,对 \(\theta_2 = 2\) 的 theta 线进行简单指数平滑 (SES)。最终的预测是通过将两个 theta 线的预测以相等权重结合而成。
Theta 方法的性能也得到了其他实证研究的确认(例如 Nikolopoulos 等人,2012 年;Petropoulos 和 Nikolopoulos,2013 年)。此外,Hyndman 和 Billah(2003 年),以下简称 H&B,表明简单指数平滑与漂移模型(SES-d)是 Theta 方法简化版本的统计模型。最近,Thomakos 和 Nikolopoulos(2014 年)提供了更多的理论见解,而 Thomakos 和 Nikolopoulos(2015 年)为该方法在多变量时间序列中的应用推导出了新的理论公式,并研究了二元 Theta 方法预期优于单变量方法的条件。尽管取得了这些进展,我们认为 Theta 方法值得预测社区更多的关注,因为它简单且具有卓越的预测性能。
Theta 方法的一个关键方面是,根据定义,它是动态的。用户可以选择不同的 theta 线,并使用相等或不等的权重结合生成的预测。然而,AN 通过将 theta 系数固定为预定义值来限制这一重要特性。
本工作的贡献有四个方面。首先,我们通过最优选择描述序列短期波动的 theta 线扩展了 A&N 方法,同时维持长期组件。从两个 theta 线得到的预测使用适当的权重进行组合,确保原始时间序列的重组。其次,我们提供了新提出的模型、原始 Theta 方法和 SES-d 模型之间的理论和实践联系。第三,我们对该模型进行了进一步扩展,允许在每个时间段修订回归线(长期组件)。
Theta方法
原始Theta方法
最初,AN提出了theta线作为方程的解
\[ \begin{equation} \nabla^2 \text{Z}_t (\theta) =\theta \nabla^2 Y_t, \ \ t=3, \cdots, n \tag 1 \end{equation} \]
其中\(Y_1, \cdots Y_n\)是原始时间序列(非季节性或已去季节化),而\(\nabla\)是差分算子(即\(\nabla X_t = X_t - X_{t-1}\))。通过最小化\(\sum_{t=1}^n [Y_t -\text{Z}_t(\theta)]^2\)来获得\(\text{Z}_1\)和\(\text{Z}_2\)的初始值。然而,通过H&B获得了计算\(\text{Z}(\theta)\)的解析解,给出如下:
\[ \begin{equation} \text{Z}_t(\theta)=\theta Y_t +(1-\theta)(A_n +B_n t) \ \ t=1, \cdots n \tag 2 \end{equation} \]
其中\(\text{A}_n\)和\(\text{B}_n\)是对\(Y_1, \cdots Y_n\)与\(1, \cdots n\)进行简单线性回归的最小平方系数,如下所示:
\[ \begin{equation} \text{A}_n =\frac{1}{n} \sum_{t=1}^{n} Y_t -\frac{n+1}{2} \text{B}_n ; \ \ \text{B}_n= \frac{6}{n^2 -1} (\frac{2}{n} \sum_{t=1}^{n} tY_t -\frac{1+n}{n} \sum_{t=1}^{n} Y_t ) \tag 3 \end{equation} \]
从这个角度来看,theta线可以被解释为直接应用于数据的线性回归模型的函数。然而,注意到\(\text{A}_n\)和\(\text{B}_n\)只是原始数据的函数,而不是Theta方法的参数。
最后,Theta方法针对提前\(h\)步的预测是对\(\text{Z}(0)\)和\(\text{Z}(2)\)的外推结果的一个随意组合(50%-50%),分别由线性回归模型和简单指数平滑模型给出。我们将上述设置称为标准Theta方法(STheta)。
构建AN的STheta方法的步骤如下: 1. 去季节化:对时间序列进行统计显著性季节性行为的检验。如果时间序列是季节性的,则满足条件 \[|r_m|>q_{1-\alpha/2} \sqrt{\frac{1+2 \sum_{i=1}^{m-1} r_{i}^2}{n} }\]
其中 \(r_k\) 表示滞后 \(k\) 的自相关函数,\(m\) 是季节周期内的期数(例如,月数据的情况下为12),\(n\) 是样本大小,\(q\) 是标准正态分布的分位数函数,\((1-\alpha)\%\) 是置信水平。A&N 选择了90%的置信水平。如果时间序列被识别为季节性,则通过经典分解法进行去季节化,假设季节成分具有乘法关系。
分解:经过季节调整的时间序列被分解为两条theta线,即线性回归线 \(\text{Z}(0)\) 和 theta线 \(\text{Z}(2)\)。
外推:\(\text{Z}(0)\) 被外推为正常的线性回归线,而 \(\text{Z}(2)\) 则使用单指数平滑(SES)进行外推。
组合:最终的预测是使用等权重组合这两条theta线的预测。
重新季节化:如果在步骤1中识别出系列为季节性,则最终预测值乘以相应的季节性指数。
优化Theta方法的模型
假设时间序列 \(Y_1, \cdots Y_n\) 是非季节性的,或者已经使用乘法经典分解方法进行了季节性调整。
令 \(X_t\) 为两个theta线的线性组合,
\[ \begin{equation} X_t=\omega \text{Z}_t (\theta_1) +(1-\omega) \text{Z}_t (\theta_2) \tag 4 \end{equation} \]
其中 \(\omega \in [0,1]\) 是权重参数。假设 \(\theta_1 <1\) 且 \(\theta_2 \geq 1\),权重 \(\omega\) 可以由以下公式得出
\[ \begin{equation} \omega:=\omega(\theta_1, \theta_2)=\frac{\theta_2 -1}{\theta_2 -\theta_1} \tag 5 \end{equation} \]
从公式 (4) 和 (5) 可以清楚地看到 \(X_t=Y_t, \ t=1, \cdots n\),即权重的计算方式使得公式 (4) 能够重现原始序列。在附录 A 的定理 1中,我们证明了解的唯一性,并且未选择最佳权重(\(\omega\) 和 \(1-\omega\))所造成的误差与线性回归模型的误差成正比。因此,STheta 方法简单地通过设置 \(\theta_1=0\) 和 \(\theta_2=2\) 而得出,其中从公式 (5) 我们得到 \(\omega=0.5\)。因此,公式 (4) 和 (5) 使我们能够构建一个Theta模型的推广,该模型保持原始时间序列对任何theta线 \(\text{Z}_t (\theta_1)\) 和 \(\text{Z}_t (\theta_2)\) 的重新组合特性。
为了保持对长期成分的建模,并与STheta方法进行公正比较,在本工作中我们固定\(\theta_1=0\),专注于短期成分的优化,即\(\theta_2=0\)且\(\theta \geq 1\)。因此,\(\theta\)是目前唯一需要估计的参数。Theta分解现在给出为:
\[Y_t=(1-\frac{1}{\theta}) (\text{A}_n+\text{B}_n t)+ \frac{1}{\theta} \text{Z}_t (\theta), \ t=1, \cdots , n\]
在原点计算的\(h\)步预测给出为:
\[ \begin{equation} \hat Y_{n+h|n} = (1-\frac{1}{\theta}) [\text{A}_n+\text{B}_n (n+h)]+ \frac{1}{\theta} \tilde {\text{Z}}_{n+h|n} (\theta) \tag 6 \end{equation} \]
其中\(\tilde {\text{Z}}_{n+h|n} (\theta)=\tilde {\text{Z}}_{n+1|n} (\theta)=\alpha \sum_{i=0}^{n-1}(1-\alpha)^i \text{Z}_{n-i}(\theta)+(1-\alpha)^n \ell_{0}^{*}\)是通过SES模型对\(\text{Z}_t(\theta)\)的外推,\(\ell_{0}^{*} \in \mathbb{R}\)为初始水平参数,\(\alpha \in (0,1)\)为平滑参数。注意,对于\(\theta=2\),方程(6)对应于STheta算法的第4步。经过一些代数运算,我们可以写出:
\[ \begin{equation} \tilde {\text{Z}}_{n+1|n} (\theta)=\theta \ell{n}+(1-\theta) \{ \text{A}_n [1-(1-\alpha)^n] + \text{B}_n [n+(1-\frac{1}{\alpha}) [1-(1-\alpha)^n] ] \} \tag 7 \end{equation} \]
其中\(\ell_{t}=\alpha Y_t +(1-\alpha) \ell_{t-1}\),\(t=1, \cdots, n\),且\(\ell_{0}=\ell_{0}^{*}/\theta\)。
结合方程(6)和(7),我们建议四种随机方法。这些方法的不同之处在于参数\(\theta\)可以是固定的两个值或优化的,同时系数\(\text{A}_n\)和\(\text{B}_n\)可以是固定值或动态函数。为了构建状态空间模型,采用\(\mu_{t}\)作为原点\(t-1\)的一步预测,\(\varepsilon_{t}\)作为相应的加性误差,即如果\(\mu_{t}= \hat Y_{t|t-1}\),则\(\varepsilon_{t}=Y_t - \mu_{t}\)。我们假设\(\{ \varepsilon_{t} \}\)是一个均值为零、方差为\(\sigma^2\)的高斯白噪声过程。
优化和标准Theta模型
设 \(\text{A}_n\) 和 \(\text{B}_n\) 为对于所有 \(t=1, \ldots, n\) 的固定系数,使得方程 (6), (7) 配置如下的状态空间模型:
\[ \begin{equation} Y_t=\mu_{t}+\varepsilon_{t} \tag 8 \end{equation} \]
\[ \begin{equation} \mu_{t}=\ell_{t-1}+(1-\frac{1}{\theta}) \{(1-\alpha)^{t-1} \text{A}_n +[\frac{1-(1-\alpha)^t}{\alpha} \text{B}_n] \tag 9 \} \end{equation} \]
\[ \begin{equation} \ell_{t}=\alpha Y_t +(1-\alpha)\ell_{t-1} \tag{10} \end{equation} \]
其中参数 \(\ell_{0} \in \mathbb{R}\), \(\alpha \in (0,1)\) 和 \(\theta \in [1,\infty)\) 。参数 \(\theta\) 需要与 \(\alpha\) 和 \(\ell_{0}\) 一起估计。我们称其为优化的Theta模型(OTM)。
从原点 \(n\) 的 \(h\) 步预测为:
\[\hat Y_{n+h|n}=E[Y_{n+h}|Y_1,\cdots, Y_n]=\ell_{n}+(1-\frac{1}{\theta}) \{(1-\alpha)^n \text{A}_n +[(h-1) + \frac{1-(1-\alpha)^{n+1}}{\alpha}] \text{B}_n \}\]
这等同于方程 (6)。条件方差 \(\text{Var}[Y_{n+h}|Y_1, \cdots, Y_n]=[1+(h-1)\alpha^2]\sigma^2\) 可以很容易地从状态空间模型中计算得出。因此,\(Y_{n+h}\) 的 \((1-\alpha)\%\) 预测区间为:
\[\hat Y_{n+h|n} \ \pm \ q_{1-\alpha/2} \sqrt{[1+(h-1)\alpha^2 ]\sigma^2 }\]
当 \(\theta=2\) 时,OTM 复现了STheta方法的预测;在此之后,我们将特指这一情况称为标准Theta模型(STM)。在 附录A 的定理2中,我们证明了OTM 在数学上等价于 SES-d 模型。作为定理2的一个推论,STM 在数学上等价于 SES-d,其中 \(b=\frac{1}{2} \text{B}_n\)。因此,对于 \(\theta=2\),该推论也再次确认了H&B结果关于 STheta 与 SES-d 模型之间关系的结果。
加载库和数据
需要使用Statsforecast。安装方法请参见说明。
接下来,我们导入绘图库并配置绘图样式。
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
'fivethirtyeight')
plt.style.use('lines.linewidth'] = 1.5
plt.rcParams[= {
dark_style 'figure.facecolor': '#212946',
'axes.facecolor': '#212946',
'savefig.facecolor':'#212946',
'axes.grid': True,
'axes.grid.which': 'both',
'axes.spines.left': False,
'axes.spines.right': False,
'axes.spines.top': False,
'axes.spines.bottom': False,
'grid.color': '#2A3459',
'grid.linewidth': '1',
'text.color': '0.9',
'axes.labelcolor': '0.9',
'xtick.color': '0.9',
'ytick.color': '0.9',
'font.size': 12 }
plt.rcParams.update(dark_style)
from pylab import rcParams
'figure.figsize'] = (18,7)
rcParams[
读取数据
= pd.read_csv("https://raw.githubusercontent.com/Naren8520/Serie-de-tiempo-con-Machine-Learning/main/Data/candy_production.csv")
df df.head()
observation_date | IPG3113N | |
---|---|---|
0 | 1972-01-01 | 85.6945 |
1 | 1972-02-01 | 71.8200 |
2 | 1972-03-01 | 66.0229 |
3 | 1972-04-01 | 64.5645 |
4 | 1972-05-01 | 65.0100 |
StatsForecast 的输入始终是一个长格式的数据框,包含三列:unique_id、ds 和 y:
unique_id
(字符串、整数或类别)表示序列的标识符。ds
(日期戳)列应采用 Pandas 预期的格式,理想情况下为 YYYY-MM-DD 的日期格式或 YYYY-MM-DD HH:MM:SS 的时间戳格式。y
(数值)表示我们希望预测的测量值。
"unique_id"]="1"
df[=["ds", "y", "unique_id"]
df.columns df.head()
ds | y | unique_id | |
---|---|---|---|
0 | 1972-01-01 | 85.6945 | 1 |
1 | 1972-02-01 | 71.8200 | 1 |
2 | 1972-03-01 | 66.0229 | 1 |
3 | 1972-04-01 | 64.5645 | 1 |
4 | 1972-05-01 | 65.0100 | 1 |
print(df.dtypes)
ds object
y float64
unique_id object
dtype: object
我们可以看到我们的时间变量 (ds)
是以对象格式表示的,我们需要将其转换为日期格式。
"ds"] = pd.to_datetime(df["ds"]) df[
使用plot方法探索数据
使用StatsForecast类中的plot方法绘制一些系列。该方法打印数据集中的随机系列,对于基本的探索性数据分析(EDA)非常有用。
from statsforecast import StatsForecast
="matplotlib") StatsForecast.plot(df, engine
自相关图
= plt.subplots(nrows=1, ncols=2)
fig, axs
"y"], lags=60, ax=axs[0],color="fuchsia")
plot_acf(df[0].set_title("Autocorrelation");
axs[
"y"], lags=60, ax=axs[1],color="lime")
plot_pacf(df[1].set_title('Partial Autocorrelation')
axs[
; plt.show()
将数据分为训练集和测试集
让我们将数据分为不同的集合 1. 用于训练我们的 AutoTheta
模型的数据 2. 用于测试我们模型的数据
对于测试数据,我们将使用最后12个月的数据来测试和评估我们模型的性能。
= df[df.ds<='2016-08-01']
train = df[df.ds>'2016-08-01'] test
train.shape, test.shape
((536, 3), (12, 3))
现在让我们绘制训练数据和测试数据。
="ds", y="y", label="Train", linewidth=3, linestyle=":")
sns.lineplot(train,x="ds", y="y", label="Test")
sns.lineplot(test, x"Candy Production")
plt.ylabel("Month")
plt.xlabel( plt.show()
AutoTheta的实现与StatsForecast
若要了解AutoTheta模型
函数的参数,以下列出了相关内容。有关更多信息,请访问文档。
season_length : int
每单位时间的观察次数。例如:24小时数据。
decomposition_type : str
季节分解类型,'multiplicative'(默认)或'additive'。
model : str
控制Theta模型。默认情况下搜索最佳模型。
alias : str
模型的自定义名称。
加载库
from statsforecast import StatsForecast
from statsforecast.models import AutoTheta
实例化模型
导入并实例化模型。设置参数有时很棘手。大师Rob Hyndmann关于季节周期的文章可能会有所帮助。
自动选择最佳的Theta(标准Theta模型 ('STM')
,优化Theta模型 ('OTM')
,动态标准Theta模型 ('DSTM')
,动态优化Theta模型 ('DOTM')
)模型,使用均方误差(mse)作为依据。
= 12 # 月度数据
season_length = len(test) # 预测数量
horizon
# 我们称之为即将使用的模型
= [AutoTheta(season_length=season_length,
models ="additive",
decomposition_type="STM")]
model
我们通过实例化一个新的 StatsForecast 对象来适配模型,使用以下参数:
models:模型列表。从模型中选择您想要的模型并导入。
freq:
一个字符串,表示数据的频率。(请参见 pandas 的可用频率。)n_jobs:
n_jobs:int,表示并行处理使用的作业数量,使用 -1 表示所有核心。fallback_model:
如果某个模型失败,则使用的备用模型。
任何设置都会传递给构造函数。然后您调用其 fit 方法,并传入历史数据框。
= StatsForecast(df=train,
sf =models,
models='MS',
freq=-1) n_jobs
拟合模型
sf.fit()
StatsForecast(models=[AutoTheta])
让我们来看一下我们的Theta模型的结果。我们可以通过以下指令观察它:
=sf.fitted_[0,0].model_
result result
{'mse': 100.64880011735451,
'amse': array([19.1126859 , 31.8559999 , 38.50771628]),
'fit': results(x=array([258.45065328, 0.7664297 ]), fn=100.57831804495909, nit=32, simplex=array([[250.37338839, 0.76970741],
[232.0391584 , 0.76429422],
[258.45065328, 0.7664297 ]])),
'residuals': array([ 2.10810474e+00, -1.13894116e+01, -9.49139749e+00, -9.82041058e+00,
-1.13263190e+01, -9.25797072e+00, -8.56619271e+00, -8.99232281e+00,
-1.98586635e+00, 3.14569304e+01, 1.98520670e+01, 2.04962612e+01,
4.98121093e+00, -1.08735318e+01, -1.12328034e+01, -8.08115311e+00,
-9.98197694e+00, -8.39937079e+00, -1.25789510e+01, -1.05952806e+01,
8.47229675e-01, 2.25644626e+01, 2.54401518e+01, 1.73989709e+01,
2.40287622e+00, -2.53475414e+00, -8.00590851e+00, -1.79241485e+01,
-6.36590415e+00, -5.76986711e+00, -2.26766747e+01, -8.95260757e+00,
-7.19719238e+00, 2.74032221e+01, 2.21368465e+01, 6.43171918e+00,
-3.51755110e+00, -1.31441914e+01, -6.13166000e+00, 1.51512233e+00,
-8.05776753e+00, -8.59603097e+00, -1.08617848e+01, -6.72940024e+00,
-6.24861877e+00, 2.85996831e+01, 2.98048000e+01, 1.90032272e+01,
5.07597932e+00, -9.59170287e+00, -1.64521040e+01, -7.52212480e+00,
-5.16540017e+00, -1.27924605e+01, -9.68434388e+00, -8.76758699e+00,
-8.27471729e-01, 3.08424000e+01, 2.47947373e+01, 2.35867234e+01,
3.75665137e+00, -4.47305376e+00, -1.48000353e+01, -1.08431533e+01,
-1.01249948e+01, -1.12379772e+01, -1.28624668e+01, -9.47780366e+00,
-2.17961621e-01, 2.49398623e+01, 1.66027796e+01, 2.62581226e+01,
-1.94878910e+00, -8.10877623e+00, -6.93183588e+00, -6.80707230e+00,
-1.17809900e+01, -1.05320691e+01, -1.59715820e+01, -9.07599954e+00,
6.11989624e-01, 2.24925204e+01, 2.57389548e+01, 2.38907624e+01,
4.99776573e+00, -1.07054647e+01, -7.24194335e+00, -1.17412089e+01,
-1.10031569e+01, -9.10138933e+00, -1.62277217e+01, -1.02585273e+01,
-2.79431182e+00, 1.96746062e+01, 2.40620697e+01, 2.00041945e+01,
8.38672864e-01, -3.01704956e-01, -1.10576385e+01, -1.76502427e+01,
-4.79852872e+00, -7.74057037e+00, -1.55628715e+01, -6.19663669e+00,
-4.85268042e+00, 2.17819362e+01, 2.48075797e+01, 2.16186239e+01,
9.21215945e+00, -1.71191246e+00, -1.38314202e+01, -9.44161648e+00,
-6.35863859e+00, -1.10470634e+01, -1.41408696e+01, -9.60040139e+00,
-4.80959573e+00, 3.41173932e+01, 2.02685750e+01, 1.65177476e+01,
1.45004512e+00, -6.65007227e-01, -1.11027908e+01, -1.82545880e+01,
-1.08637859e+01, -9.67573744e+00, -1.22946690e+01, -1.02064789e+01,
-2.94225752e+00, 3.21840486e+01, 2.21586037e+01, 2.09073985e+01,
-2.49861035e-01, -6.05605693e+00, -1.16741817e+01, -1.31096440e+01,
-1.07043792e+01, -1.25489002e+01, -9.16715967e+00, -7.70278796e+00,
-2.55656615e+00, 2.69936390e+01, 1.62042825e+01, 1.67614454e+01,
8.62186772e+00, -3.51518872e+00, -9.27420876e+00, -1.15442800e+01,
-9.96136035e+00, -1.17898547e+01, -1.13147631e+01, -7.10440289e+00,
-1.10170687e+00, 2.60646470e+01, 2.32687962e+01, 1.82272075e+01,
3.98792429e+00, -7.64233456e+00, -1.07945924e+01, -1.16024033e+01,
-1.10645356e+01, -1.33282252e+01, -1.15534869e+01, -6.76285862e+00,
3.93787288e+00, 2.37018466e+01, 2.07922165e+01, 2.37645525e+01,
7.00181250e-01, -1.59605643e+00, -1.62277551e+01, -1.51068272e+01,
-1.01377617e+01, -1.13639555e+01, -1.38275875e+01, -5.87092406e+00,
3.43469843e+00, 2.82932152e+01, 2.39510183e+01, 1.71053567e+01,
6.00991052e-01, -7.61227344e-01, -1.18686652e+01, -1.51989720e+01,
-1.23352889e+01, -1.09931328e+01, -1.34086750e+01, -4.52127936e+00,
2.09363456e+00, 3.13825833e+01, 2.43980039e+01, 1.89899596e+01,
-7.55701831e+00, -2.76896118e-01, -6.52574213e+00, -1.67167273e+01,
-1.17498904e+01, -7.68050375e+00, -5.60844064e+00, -2.79087748e+00,
-2.92096472e-01, 2.31896512e+01, 1.70158839e+01, 1.84177146e+01,
-3.39879932e-01, 1.31241678e+00, -9.65552310e+00, -1.30840504e+01,
-1.33540007e+01, -9.72077532e+00, -1.09022915e+01, -4.49636293e+00,
-6.88545386e-01, 1.88878505e+01, 2.15227108e+01, 2.32009713e+01,
-5.72605410e+00, 1.87746304e+00, -6.95944417e+00, -1.41944243e+01,
-1.25398538e+01, -8.09461371e+00, -5.46316779e+00, -4.73324801e+00,
1.12162623e+00, 1.61183549e+01, 2.63470348e+01, 2.28827932e+01,
-6.75326627e+00, 4.34024335e+00, -6.61711421e+00, -1.64533692e+01,
-1.44473755e+01, -4.85575266e+00, -1.14659676e+01, -1.83411691e+00,
-3.17491990e+00, 1.22586080e+01, 2.19162170e+01, 1.62630875e+01,
-1.99943217e+00, 2.59143066e-03, -8.89995746e+00, -1.10976710e+01,
-1.43864442e+01, -9.48222086e+00, -1.06785679e+01, -7.24340992e+00,
2.15093059e+00, 1.53607646e+01, 2.06126872e+01, 1.96076193e+01,
3.03104513e+00, -8.52381714e-02, -8.52357291e+00, -1.33461539e+01,
-1.37600223e+01, -6.08840914e+00, -8.32367614e+00, -3.02117355e+00,
4.08613501e-01, 1.63346128e+01, 1.76259441e+01, 1.75724074e+01,
1.52688580e+00, -2.23616547e+00, -3.82137115e+00, -1.61943645e+01,
-1.55739800e+01, -6.10489281e+00, -6.56542552e+00, -3.79160289e+00,
1.79366877e+00, 1.37690217e+01, 1.71703998e+01, 2.12968995e+01,
2.55881699e+00, -5.89333546e+00, -5.43867789e+00, -9.34441587e+00,
-1.23296390e+01, -7.43701646e+00, -9.59827596e+00, -6.98198281e+00,
-7.94913086e-01, 1.30601029e+01, 2.03392159e+01, 2.52824428e+01,
-3.95418141e+00, 2.43162639e+00, -3.09610920e+00, -1.49779650e+01,
-1.07287620e+01, -8.40149865e+00, -1.18887461e+01, -1.74756743e+00,
2.17909199e+00, 1.20038424e+01, 2.42508105e+01, 2.34572744e+01,
-5.17568502e+00, -1.96581848e-01, -4.18458181e+00, -1.55119014e+01,
-1.38833806e+01, -8.29522400e+00, -1.30003229e+01, -1.67000232e-01,
9.35163928e-01, 1.47273988e+01, 2.29308492e+01, 2.17103709e+01,
3.68218624e+00, 2.64748987e-01, -7.34442764e+00, -1.25122409e+01,
-1.14503466e+01, -8.19533607e+00, -1.15456959e+01, -2.81694144e+00,
-1.50157905e+00, 1.14252476e+01, 2.08253666e+01, 1.93274968e+01,
7.94221350e-01, -5.10391223e-01, -8.74258215e+00, -9.01560883e+00,
-1.00192385e+01, -1.10908755e+01, -1.09129047e+01, -6.64424323e+00,
-1.50482360e+00, 1.46897903e+01, 1.73829646e+01, 2.23508538e+01,
8.64908630e+00, 6.22671094e-01, -6.68012804e+00, -5.70808134e+00,
-1.80391994e+01, -7.97570029e+00, -1.19962951e+01, -5.55858845e+00,
2.35415415e+00, 1.17526341e+01, 1.54009299e+01, 2.21564078e+01,
3.90927294e+00, 2.21699393e+00, -3.80724429e+00, -1.09345637e+01,
-1.37938445e+01, -1.00726067e+01, -1.19963695e+01, -5.40000947e+00,
-1.51911075e+00, 1.69895894e+00, 1.74367915e+01, 2.04883209e+01,
7.55305128e+00, 7.29570557e-01, -5.09536064e+00, -1.29493288e+01,
-1.53454354e+01, -2.46711146e+00, -1.01903530e+01, -4.03697109e+00,
-3.08084248e+00, 3.86928003e+00, 1.92764138e+01, 1.55958083e+01,
7.35560975e+00, 1.85905807e+00, -5.61642383e-01, -1.23394878e+01,
-9.90369395e+00, -7.50968535e+00, -1.83651427e+01, -2.77916462e+00,
-1.07805579e+00, 8.15877069e+00, 2.33477145e+01, 1.69720399e+01,
6.19355475e+00, 4.92033569e+00, -1.36452248e+01, -1.10382213e+01,
-4.45625718e+00, -1.37976279e+01, -1.12070256e+01, -1.28293835e+00,
1.02619080e-01, 1.16373460e+01, 1.73964054e+01, 1.64050950e+01,
1.32632343e+01, 4.44789636e+00, -1.66636731e+01, -1.04932464e+01,
-7.27536875e+00, -1.52095873e+01, -8.33331642e+00, -6.12562858e+00,
-6.19892773e-01, 1.73375832e+01, 1.71076149e+01, 2.30092407e+01,
-1.39793491e+00, 1.20108252e+00, -1.01506302e+01, -9.35708666e+00,
-1.72524976e+01, -1.33257495e+01, -1.11436034e+01, -1.07822300e+00,
2.29722831e+00, 1.15489407e+01, 1.72661569e+01, 2.11762694e+01,
9.51783992e+00, -1.02191544e+00, -5.14895468e+00, -2.05301462e+01,
-1.56429874e+01, -1.60412132e+01, -1.50915598e+01, -2.94815391e+00,
4.61947153e+00, 6.94204539e+00, 1.79378224e+01, 2.19333492e+01,
8.01926611e+00, -3.09873730e+00, -6.33384070e+00, -1.29667977e+01,
-1.54450147e+01, -1.27736740e+01, -1.46733540e+01, -8.76926667e+00,
8.56843032e+00, 1.28259081e+01, 1.86473179e+01, 5.73666716e+00,
4.33460410e+00, 2.08833633e+00, -3.96959197e+00, -1.29223836e+01,
-1.19550430e+01, -1.27279222e+01, -8.02537455e+00, -3.92330203e+00,
7.09140653e+00, 2.42153186e+01, 1.28924490e+01, 1.79712039e+01,
2.89523345e+00, 1.30474561e+00, -7.77941964e+00, -1.04361436e+01,
-1.14357282e+01, -1.23868253e+01, -3.73409886e+00, 6.47317969e-01,
5.14176714e+00, 1.16621419e+01, 8.00349658e+00, 1.83900836e+01,
3.46846826e+00, 2.29413411e+00, -4.06962429e+00, -8.55164816e+00,
-1.76399687e+01, -1.50423464e+01, -1.13765493e+01, -9.17973348e+00,
-4.22253840e+00, 2.19090146e+01, 1.90170613e+01, 1.80606320e+01,
4.08981908e+00, 2.02346575e+00, -5.45474178e+00, -1.38725735e+01,
-1.50622767e+01, -1.15367785e+01, -7.55445662e+00, -1.77510786e+00,
9.46336210e+00, 4.88813442e+00, 1.61490921e+01, 1.93212523e+01,
1.03075606e+01, -6.46757556e-01, -5.79533411e-01, -1.35917688e+01,
-1.62148895e+01, -1.29823914e+01, -1.02149059e+01, -3.24210991e+00,
3.05411548e-01, 1.19385124e+01, 2.08979522e+01, 2.19927488e+01,
1.32364111e+00, 1.68626682e+00, -3.52030260e+00, -1.50337393e+01,
-1.75865889e+01, -1.23980830e+01, -1.19670278e+01, -1.59575155e+00,
4.32015198e+00, 1.39461341e+01, 2.63901700e+01, 2.11431702e+01,
1.19960846e+00, 1.22769642e+00, -3.12850968e+00, -1.23388318e+01,
-1.66429415e+01, -9.08277175e+00, -7.92637109e+00, 2.43702710e+00,
-3.53211257e+00, 1.00606809e+01, 1.39608449e+01, 1.44689445e+01,
6.50770870e+00, 3.13941359e+00, -4.89894849e-01, -1.05833253e+01,
-1.34863092e+01, -1.20763816e+01, -1.00738931e+01, -9.39207509e+00]),
'm': 12,
'states': array([[7.33548965e+01, 8.30544205e+01, 8.40569077e+01, 6.14692122e-02,
8.35863953e+01],
[7.31818390e+01, 7.80917587e+01, 8.40569077e+01, 6.14692122e-02,
8.32094116e+01],
[7.38093872e+01, 7.67280502e+01, 8.40569077e+01, 6.14692122e-02,
7.55142975e+01],
...,
[1.12984993e+02, 1.00517860e+02, 8.40569077e+01, 6.14692122e-02,
1.14480782e+02],
[1.14049675e+02, 1.00543762e+02, 8.40569077e+01, 6.14692122e-02,
1.13025093e+02],
[1.10946037e+02, 1.00561401e+02, 8.40569077e+01, 6.14692122e-02,
1.14089775e+02]], dtype=float32),
'par': {'initial_smoothed': 41.527209666340006,
'alpha': 0.7664297044277077,
'theta': 2.0},
'n': 536,
'modeltype': 'STM',
'mean_y': 100.56138830499272,
'decompose': True,
'decomposition_type': 'additive',
'seas_forecast': {'mean': array([ 0.08977811, 18.09442 , 20.248487 , 19.430647 ,
2.6400807 , -1.3090991 , -7.977731 , -12.3264065 ,
-12.027774 , -10.136967 , -11.4229355 , -5.3025 ],
dtype=float32)},
'fitted': array([ 83.58639526, 83.20941162, 75.51429749, 74.38491058,
76.33631897, 76.90467072, 77.60909271, 79.82932281,
77.03206635, 75.4719696 , 85.744133 , 85.47103882,
86.31848907, 88.1435318 , 80.84380341, 78.37975311,
81.66417694, 83.26287079, 84.62535095, 83.77008057,
79.74427032, 80.35553741, 83.81224823, 87.82202911,
86.29562378, 86.14455414, 85.23590851, 85.24504852,
80.98550415, 85.35566711, 88.73347473, 80.13900757,
77.37219238, 71.81797791, 78.98325348, 80.46128082,
70.5292511 , 65.84059143, 56.80056 , 58.24617767,
68.88546753, 71.95893097, 73.17068481, 73.63150024,
72.56861877, 67.74141693, 75.82369995, 83.17867279,
82.88182068, 84.77950287, 78.46220398, 71.9979248 ,
75.71080017, 81.00106049, 78.99654388, 80.35978699,
77.73477173, 77.0625 , 86.86366272, 90.37877655,
93.59484863, 94.48135376, 92.08713531, 86.88905334,
88.05659485, 89.54567719, 88.73256683, 87.66000366,
84.49066162, 84.28553772, 89.56282043, 86.79937744,
92.0628891 , 88.57657623, 83.39583588, 84.2281723 ,
88.48908997, 88.70896912, 88.43688202, 84.98139954,
82.12001038, 82.55097961, 85.95254517, 90.19133759,
93.64043427, 95.47816467, 88.30724335, 88.90190887,
89.38115692, 90.19718933, 91.0216217 , 87.36982727,
83.60211182, 81.4223938 , 82.66423035, 85.61780548,
86.08812714, 84.73820496, 85.54103851, 83.21124268,
79.16162872, 84.73307037, 86.6004715 , 83.45823669,
82.80368042, 79.04636383, 81.90332031, 85.42827606,
87.13594055, 92.20371246, 91.92572021, 87.47001648,
89.71173859, 94.08746338, 93.42066956, 91.36830139,
88.10499573, 84.38070679, 96.69192505, 96.73805237,
94.53625488, 93.65490723, 94.17929077, 91.814888 ,
87.30208588, 88.22493744, 88.60916901, 87.97177887,
84.24395752, 81.95085144, 92.78029633, 94.27500153,
95.43756104, 93.25335693, 89.64588165, 86.84354401,
86.27397919, 87.31900024, 85.50115967, 87.26078796,
85.45186615, 83.45436096, 90.30571747, 87.23685455,
85.22183228, 89.83718872, 88.17710876, 87.21417999,
87.84436035, 89.45885468, 88.22276306, 88.33640289,
86.98610687, 86.10365295, 92.24300385, 94.58859253,
93.69697571, 94.76073456, 89.93749237, 87.80930328,
88.39493561, 89.16392517, 86.74878693, 86.67945862,
85.59092712, 88.57095337, 92.89938354, 93.34684753,
96.69921875, 95.24315643, 95.05395508, 88.7616272 ,
86.66136169, 88.14065552, 87.23098755, 85.41872406,
85.01380157, 87.60818481, 95.45558167, 98.32404327,
96.57260895, 95.04052734, 95.49116516, 92.53977203,
90.36888885, 90.1639328 , 89.53847504, 88.04727936,
88.67676544, 90.24331665, 100.45849609, 103.66954041,
103.36251831, 95.57789612, 96.39974213, 97.54332733,
94.20919037, 94.45290375, 96.36634064, 100.85347748,
102.80919647, 102.54724884, 106.48311615, 104.0362854 ,
103.29067993, 101.03748322, 103.0774231 , 101.82225037,
101.27230072, 100.28657532, 100.6362915 , 101.06606293,
101.71464539, 101.14884949, 101.78768921, 102.79502869,
105.7154541 , 99.33413696, 101.80714417, 102.61832428,
101.21735382, 100.85561371, 102.45166779, 107.05014801,
107.51717377, 108.33874512, 106.85496521, 111.55980682,
114.23636627, 107.06775665, 111.42831421, 112.5018692 ,
109.3695755 , 107.54585266, 111.62426758, 111.62201691,
114.3110199 , 111.83959198, 107.39758301, 108.70651245,
106.30953217, 102.78440857, 103.82045746, 103.14437103,
104.11684418, 102.33982086, 102.87236786, 103.47360992,
102.01676941, 103.62723541, 101.56281281, 101.87268066,
102.03905487, 102.36943817, 103.33817291, 102.95055389,
102.19972229, 100.90280914, 104.03647614, 106.44257355,
108.2217865 , 108.49688721, 107.1788559 , 105.19959259,
103.8061142 , 102.98366547, 102.30387115, 105.52016449,
102.58638 , 99.89919281, 103.02022552, 106.77390289,
107.96263123, 109.29927826, 106.01490021, 103.68650055,
105.14758301, 105.11603546, 101.63327789, 103.61001587,
105.92623901, 105.72561646, 107.82567596, 109.25488281,
107.99841309, 107.35109711, 103.52338409, 103.62365723,
108.13938141, 103.11607361, 106.0138092 , 109.78596497,
107.78446198, 108.81079865, 110.17164612, 109.84536743,
112.60070801, 114.23275757, 109.5954895 , 112.69372559,
115.81058502, 109.85108185, 110.73448181, 113.67240143,
111.2616806 , 109.870224 , 111.31252289, 110.13430023,
114.10103607, 114.77970123, 112.22985077, 114.31642914,
116.09441376, 116.92385101, 118.16082764, 118.67694092,
118.56524658, 119.03853607, 120.55739594, 120.49404144,
122.42977905, 121.24085236, 116.16013336, 116.63300323,
116.58467865, 115.20069122, 115.84358215, 115.28810883,
117.8563385 , 119.42647552, 118.72610474, 119.14774323,
118.1501236 , 116.95870972, 114.3800354 , 112.2145462 ,
114.4834137 , 119.11962891, 120.63092804, 121.65618134,
126.75939941, 122.18280029, 123.86999512, 123.46128845,
123.29574585, 125.06196594, 120.2321701 , 116.54759216,
118.66742706, 119.67090607, 122.40414429, 125.63126373,
126.72874451, 125.40590668, 125.48596954, 125.07720947,
125.03321075, 123.83084106, 111.29560852, 109.17137909,
110.01274872, 113.80892944, 115.40216064, 117.64202881,
117.19533539, 114.68331146, 120.592453 , 121.56787109,
122.56854248, 120.16921997, 109.29738617, 108.58309174,
105.67469025, 109.31954193, 111.77844238, 117.49308777,
117.51379395, 119.17248535, 121.21684265, 115.92686462,
117.89155579, 117.02722931, 109.44298553, 111.84906006,
109.99544525, 112.74966431, 117.55482483, 113.24182129,
114.25985718, 120.09362793, 117.31872559, 117.51493835,
120.62638092, 120.66695404, 115.74879456, 113.59360504,
111.3054657 , 119.47810364, 123.9211731 , 117.2947464 ,
118.73046875, 122.40358734, 118.54651642, 120.94522858,
120.34509277, 119.83191681, 119.28258514, 116.90605927,
119.67953491, 116.61541748, 118.57003021, 116.93538666,
119.24189758, 115.26824951, 112.85500336, 113.099823 ,
116.36817169, 118.09075928, 113.10484314, 110.84983063,
112.21846008, 117.52051544, 117.77135468, 119.97014618,
113.71328735, 110.9732132 , 106.47875977, 103.69775391,
105.53292847, 109.03535461, 100.51857758, 98.77835083,
100.72723389, 104.8807373 , 103.5398407 , 104.83049774,
104.37041473, 101.78207397, 99.79195404, 97.33146667,
94.70516968, 101.23419189, 97.22698212, 96.03053284,
85.5657959 , 86.89526367, 89.52989197, 92.63258362,
92.20654297, 92.29302216, 90.33797455, 92.97270203,
94.06049347, 99.45748138, 104.17945099, 98.57229614,
97.48446655, 97.71075439, 99.74481964, 99.92754364,
101.4070282 , 101.89152527, 100.19789886, 106.12158203,
110.71243286, 114.61515808, 109.71600342, 100.36181641,
99.59503174, 100.26066589, 103.05302429, 106.07904816,
109.00286865, 104.7322464 , 101.0033493 , 101.06963348,
98.1287384 , 94.85438538, 97.80873871, 96.89566803,
95.87638092, 97.01823425, 99.60314178, 101.56757355,
100.41327667, 98.1182785 , 97.07615662, 100.07180786,
102.8060379 , 110.02096558, 99.93000793, 96.81884766,
96.76573944, 102.67305756, 103.21143341, 108.91236877,
107.97328949, 104.79489136, 102.64480591, 103.60140991,
105.21128845, 105.4072876 , 100.71994781, 101.24845123,
103.24285889, 102.26463318, 104.5911026 , 108.03813934,
105.99388885, 101.76418304, 100.0619278 , 99.67565155,
102.54734802, 105.82036591, 102.67173004, 107.40962982,
108.75289154, 107.67960358, 109.65460968, 113.40193176,
113.42314148, 109.91667175, 110.75537109, 113.4659729 ,
119.42851257, 116.68331909, 110.55675507, 105.76845551,
101.9963913 , 104.99138641, 108.43159485, 114.20122528,
115.56790924, 114.48078156, 113.02509308, 114.08977509])}
让我们现在可视化我们模型的残差。
正如我们所看到的,上面获得的结果输出为一个字典,提取字典中的每个元素,我们将使用.get()
函数来提取该元素,然后将其保存到pd.DataFrame()
中。
=pd.DataFrame(result.get("residuals"), columns=["residual Model"])
residual residual
residual Model | |
---|---|
0 | 2.108105 |
1 | -11.389412 |
2 | -9.491397 |
... | ... |
533 | -12.076382 |
534 | -10.073893 |
535 | -9.392075 |
536 rows × 1 columns
= plt.subplots(nrows=2, ncols=2)
fig, axs
=axs[0,0])
residual.plot(ax0,0].set_title("Residuals");
axs[
=axs[0,1]);
sns.distplot(residual, ax0,1].set_title("Density plot - Residual");
axs[
"residual Model"], dist="norm", plot=axs[1,0])
stats.probplot(residual[1,0].set_title('Plot Q-Q')
axs[
=35, ax=axs[1,1],color="fuchsia")
plot_acf(residual, lags1,1].set_title("Autocorrelation");
axs[
; plt.show()
预测方法
如果您希望在生产环境中处理多个系列或模型时提高速度,我们建议使用 StatsForecast.forecast
方法,而不是 .fit
和 .predict
。
主要区别在于 .forecast
不会存储拟合值,并且在分布式环境中具有很高的可扩展性。
预测方法接受两个参数:预测下一个 h
(时间范围)和 level
。
h (int):
表示预测未来 h 步。在这种情况下,是 12 个月。level (list of floats):
这个可选参数用于概率预测。设置您预测区间的水平(或置信百分位数)。例如,level=[90]
表示模型预计真实值在该区间内的概率为 90%。
此处的预测对象是一个新的数据框,包含一个模型名称列和 y hat 值,以及不确定性区间的列。根据您的计算机,这一步应该大约需要 1 分钟。(如果您希望将时间缩短到几秒钟,请移除像 ARIMA
和 Theta
这样的 AutoModels。)
# 预测
= sf.forecast(horizon, fitted=True)
Y_hat
Y_hat
ds | AutoTheta | |
---|---|---|
unique_id | ||
1 | 2016-09-01 | 111.075912 |
1 | 2016-10-01 | 129.111282 |
1 | 2016-11-01 | 131.296082 |
... | ... | ... |
1 | 2017-06-01 | 101.125748 |
1 | 2017-07-01 | 99.870514 |
1 | 2017-08-01 | 106.021683 |
12 rows × 2 columns
=sf.forecast_fitted_values()
values values.head()
ds | y | AutoTheta | |
---|---|---|---|
unique_id | |||
1 | 1972-01-01 | 85.694504 | 83.586395 |
1 | 1972-02-01 | 71.820000 | 83.209412 |
1 | 1972-03-01 | 66.022903 | 75.514297 |
1 | 1972-04-01 | 64.564499 | 74.384911 |
1 | 1972-05-01 | 65.010002 | 76.336319 |
StatsForecast.plot(values)
添加95%置信区间与预测方法
=horizon, level=[95]) sf.forecast(h
ds | AutoTheta | AutoTheta-lo-95 | AutoTheta-hi-95 | |
---|---|---|---|---|
unique_id | ||||
1 | 2016-09-01 | 111.075912 | 90.148819 | 135.999680 |
1 | 2016-10-01 | 129.111282 | 94.811134 | 160.372803 |
1 | 2016-11-01 | 131.296082 | 90.598457 | 168.251602 |
... | ... | ... | ... | ... |
1 | 2017-06-01 | 101.125748 | 41.213715 | 159.133316 |
1 | 2017-07-01 | 99.870514 | 35.173969 | 152.843002 |
1 | 2017-08-01 | 106.021683 | 38.784256 | 166.021072 |
12 rows × 4 columns
=Y_hat.reset_index()
Y_hat Y_hat
unique_id | ds | AutoTheta | |
---|---|---|---|
0 | 1 | 2016-09-01 | 111.075912 |
1 | 1 | 2016-10-01 | 129.111282 |
2 | 1 | 2016-11-01 | 131.296082 |
... | ... | ... | ... |
9 | 1 | 2017-06-01 | 101.125748 |
10 | 1 | 2017-07-01 | 99.870514 |
11 | 1 | 2017-08-01 | 106.021683 |
12 rows × 3 columns
# 将预测结果与真实值合并
'unique_id'] = test['unique_id'].astype(int)
test[= test.merge(Y_hat, how='left', on=['unique_id', 'ds'])
Y_hat1 Y_hat1
ds | y | unique_id | AutoTheta | |
---|---|---|---|---|
0 | 2016-09-01 | 109.3191 | 1 | 111.075912 |
1 | 2016-10-01 | 119.0502 | 1 | 129.111282 |
2 | 2016-11-01 | 116.8431 | 1 | 131.296082 |
... | ... | ... | ... | ... |
9 | 2017-06-01 | 104.2022 | 1 | 101.125748 |
10 | 2017-07-01 | 102.5861 | 1 | 99.870514 |
11 | 2017-08-01 | 114.0613 | 1 | 106.021683 |
12 rows × 4 columns
= plt.subplots(1, 1)
fig, ax = pd.concat([train, Y_hat1]).set_index('ds')
plot_df 'y', "AutoTheta"]].plot(ax=ax, linewidth=2)
plot_df[[' Forecast', fontsize=22)
ax.set_title('Year ', fontsize=20)
ax.set_ylabel('Timestamp [t]', fontsize=20)
ax.set_xlabel(={'size': 15})
ax.legend(propTrue) ax.grid(
预测方法与置信区间
要生成预测,请使用predict方法。
predict方法接受两个参数:预测接下来的h
(代表时间跨度)和level
。
h (int):
表示预测未来h步。在本例中为12个月。level (list of floats):
该可选参数用于概率预测。设置您的预测区间的水平(或置信百分位)。例如,level=[95]
意味着模型期望真实值在该区间内的概率为95%。
此处的预测对象是一个新的数据框,包含一个模型名称的列和y hat值,以及不确定性区间的列。
此步骤应耗时少于1秒。
=horizon) sf.predict(h
ds | AutoTheta | |
---|---|---|
unique_id | ||
1 | 2016-09-01 | 111.075912 |
1 | 2016-10-01 | 129.111282 |
1 | 2016-11-01 | 131.296082 |
... | ... | ... |
1 | 2017-06-01 | 101.125748 |
1 | 2017-07-01 | 99.870514 |
1 | 2017-08-01 | 106.021683 |
12 rows × 2 columns
= sf.predict(h=horizon, level=[95])
forecast_df
forecast_df
ds | AutoTheta | AutoTheta-lo-95 | AutoTheta-hi-95 | |
---|---|---|---|---|
unique_id | ||||
1 | 2016-09-01 | 111.075912 | 90.148819 | 135.999680 |
1 | 2016-10-01 | 129.111282 | 94.811134 | 160.372803 |
1 | 2016-11-01 | 131.296082 | 90.598457 | 168.251602 |
... | ... | ... | ... | ... |
1 | 2017-06-01 | 101.125748 | 41.213715 | 159.133316 |
1 | 2017-07-01 | 99.870514 | 35.173969 | 152.843002 |
1 | 2017-08-01 | 106.021683 | 38.784256 | 166.021072 |
12 rows × 4 columns
我们可以使用 pandas 函数 pd.concat()
将预测结果与历史数据连接起来,然后可以用这个结果进行绘图。
'ds') pd.concat([df, forecast_df]).set_index(
y | unique_id | AutoTheta | AutoTheta-lo-95 | AutoTheta-hi-95 | |
---|---|---|---|---|---|
ds | |||||
1972-01-01 | 85.6945 | 1 | NaN | NaN | NaN |
1972-02-01 | 71.8200 | 1 | NaN | NaN | NaN |
1972-03-01 | 66.0229 | 1 | NaN | NaN | NaN |
... | ... | ... | ... | ... | ... |
2017-06-01 | NaN | NaN | 101.125748 | 41.213715 | 159.133316 |
2017-07-01 | NaN | NaN | 99.870514 | 35.173969 | 152.843002 |
2017-08-01 | NaN | NaN | 106.021683 | 38.784256 | 166.021072 |
560 rows × 5 columns
现在让我们可视化我们的预测结果和时间序列的历史数据,同时绘制我们在进行95%置信度预测时获得的置信区间。
def plot_forecasts(y_hist, y_true, y_pred, models):
= plt.subplots(1, 1, figsize = (20, 7))
_, ax = y_true.merge(y_pred, how='left', on=['unique_id', 'ds'])
y_true = pd.concat([y_hist, y_true]).set_index('ds').tail(12*10)
df_plot 'y'] + models].plot(ax=ax, linewidth=2)
df_plot[[= ['orange', 'black', 'green']
colors for model, color in zip(models, colors):
ax.fill_between(df_plot.index, f'{model}-lo-95'],
df_plot[f'{model}-hi-95'],
df_plot[=.35,
alpha=color,
color=f'{model}-level-95')
label'', fontsize=22)
ax.set_title('', fontsize=20)
ax.set_ylabel('Timestamp [t]', fontsize=20)
ax.set_xlabel(={'size': 15})
ax.legend(propTrue) ax.grid(
=['AutoTheta']) plot_forecasts(train, test, forecast_df, models
让我们使用Statsforecast
中自带的plot函数绘制相同的图形,如下所示。
=[95]) sf.plot(df, forecast_df, level
交叉验证
在之前的步骤中,我们利用历史数据来预测未来。然而,为了评估其准确性,我们还希望了解模型在过去的表现。为了评估模型在数据上的准确性和稳健性,可以进行交叉验证。
对于时间序列数据,交叉验证是通过在历史数据上定义一个滑动窗口并预测随后的时期来完成的。这种交叉验证的形式使我们能够更好地估计模型在更广泛时间点的预测能力,同时保持训练集中的数据连续,正如我们的模型所要求的那样。
下图展示了这种交叉验证策略:
进行时间序列交叉验证
时间序列模型的交叉验证被视为最佳实践,但大多数实现速度都很慢。statsforecast库将交叉验证实现为分布式操作,使得过程变得不那么耗时。如果您拥有大数据集,您还可以使用Ray、Dask或Spark在分布式集群中执行交叉验证。
在这种情况下,我们想要评估每个模型在过去5个月的表现(n_windows=5)
,每隔两个月预测一次(step_size=12)
。根据您的计算机,这一步应该大约需要1分钟。
StatsForecast类中的cross_validation方法接受以下参数。
df:
训练数据框h (int):
表示要预测的未来h步。在这种情况下,是12个月。step_size (int):
每个窗口之间的步长。换句话说:您希望多久运行一次预测过程。n_windows(int):
用于交叉验证的窗口数量。换句话说:您希望评估过去多少个预测过程。
= sf.cross_validation(df=train,
crossvalidation_df =horizon,
h=12,
step_size=5) n_windows
crossvaldation_df对象是一个新的数据框,包含以下列:
unique_id:
索引。如果您不喜欢使用索引,只需运行crossvalidation_df.resetindex()ds:
日期戳或时间索引cutoff:
n_windows的最后一个日期戳或时间索引。y:
真实值"model":
包含模型名称和拟合值的列。
crossvalidation_df
ds | cutoff | y | AutoTheta | |
---|---|---|---|---|
unique_id | ||||
1 | 2011-09-01 | 2011-08-01 | 93.906197 | 98.167465 |
1 | 2011-10-01 | 2011-08-01 | 116.763397 | 116.969933 |
1 | 2011-11-01 | 2011-08-01 | 116.825798 | 119.135147 |
... | ... | ... | ... | ... |
1 | 2016-06-01 | 2015-08-01 | 102.404404 | 109.600456 |
1 | 2016-07-01 | 2015-08-01 | 102.951202 | 108.260147 |
1 | 2016-08-01 | 2015-08-01 | 104.697701 | 114.248260 |
60 rows × 4 columns
模型评估
我们现在可以使用合适的准确性度量来计算预测的准确性。在这里,我们将使用均方根误差(RMSE)。为此,我们首先需要 install datasetsforecast
,这是一个由 Nixtla 开发的 Python 库,其中包含计算 RMSE 的函数。
%%capture
!pip install datasetsforecast
from datasetsforecast.losses import rmse
计算RMSE的函数有两个参数:
- 实际值。
- 预测值,在本例中为AutoTheta。
= rmse(crossvalidation_df['y'], crossvalidation_df["AutoTheta"])
rmse print("RMSE using cross-validation: ", rmse)
RMSE using cross-validation: 6.9269824
正如您所注意到的,我们使用交叉验证结果来评估我们的模型。
现在我们将使用预测结果来评估我们的模型,我们将使用不同类型的指标 MAE, MAPE, MASE, RMSE, SMAPE
来评估 准确性
。
from datasetsforecast.losses import (mae, mape, mase, rmse, smape)
def evaluate_performace(y_hist, y_true, y_pred, model):
= y_true.merge(y_pred, how='left', on=['unique_id', 'ds'])
y_true = {}
evaluation = {}
evaluation[model] for metric in [mase, mae, mape, rmse, smape]:
= metric.__name__
metric_name if metric_name == 'mase':
= metric(y_true['y'].values,
evaluation[model][metric_name]
y_true[model].values, 'y'].values, seasonality=12)
y_hist[else:
= metric(y_true['y'].values, y_true[model].values)
evaluation[model][metric_name] return pd.DataFrame(evaluation).T
="AutoTheta") evaluate_performace(train, test, Y_hat, model
mae | mape | mase | rmse | smape | |
---|---|---|---|---|---|
AutoTheta | 6.281525 | 5.568355 | 1.212475 | 7.683672 | 5.479727 |
鸣谢
我们要感谢Naren Castellon编写本教程。
参考文献
Give us a ⭐ on Github