TFT 解释器 用于 Temporal Fusion Transformer (TFTModel)

TFTExplainer 使用一个训练好的 TFTModel 并从模型中提取解释性信息。

  • plot_variable_selection() 绘制每个输入特征的变量选择权重。 - 编码器重要性:目标的历史部分、过去的协变量和未来协变量的历史部分 - 解码器重要性:未来协变量的未来部分 - 静态协变量重要性:数值和分类静态协变量的重要性

  • plot_attention() 绘制了 TFTModel 对给定的过去和未来输入应用的转换器注意力。注意力在所有注意力头之间进行了聚合。

可以使用 explain() 返回的 TFTExplainabilityResult 提取注意力和特征重要性值。方法描述中展示了一个示例。

我们还展示了如何在 TFTModel 的示例笔记本中使用 TFTExplainer,请参见 这里

class darts.explainability.tft_explainer.TFTExplainer(model, background_series=None, background_past_covariates=None, background_future_covariates=None)[源代码]

基类:_ForecastingModelExplainer

TFTModel 的解释器类。

定义

  • 背景系列是一个用于生成可解释性结果的 TimeSeries ,如果未向 explain() 传递 foreground ,则使用该默认值。

  • 前景系列是一个可以传递给 explain()TimeSeries,用于生成可解释性结果,而不是使用背景。

参数
  • model (darts.models.forecasting.tft_model.TFTModel) – 要解释的拟合 TFTModel

  • background_series (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,一系列或一系列列表,用作解释的默认目标系列。如果 model 是在单个目标系列上训练的,则此项为可选。默认情况下,它是拟合时使用的 series。如果 model 是在多个(序列的)目标系列上训练的,则此项为必填。

  • background_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,一个过去的协变量序列或序列列表,用作解释的默认过去协变量序列。与 background_series 相同的要求适用。

  • background_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,一个未来协变量序列或作为解释默认未来协变量序列的序列列表。与 background_series 相同的要求适用。

实际案例

>>> from darts.datasets import AirPassengersDataset
>>> from darts.explainability.tft_explainer import TFTExplainer
>>> from darts.models import TFTModel
>>> series = AirPassengersDataset().load()
>>> model = TFTModel(
>>>     input_chunk_length=12,
>>>     output_chunk_length=6,
>>>     add_encoders={"cyclic": {"future": ["hour"]}}
>>> )
>>> model.fit(series)
>>> # create the explainer and generate explanations
>>> explainer = TFTExplainer(model)
>>> results = explainer.explain()
>>> # plot the results
>>> explainer.plot_attention(results, plot_type="all")
>>> explainer.plot_variable_selection(results)

方法

explain([foreground_series, ...])

返回 foreground_series 中所有序列的 TFTExplainabilityResult 结果。

plot_attention(expl_result[, plot_type, ...])

绘制 TFTModel 的注意力头。

plot_variable_selection(expl_result[, ...])

根据输入绘制 TFTModel 的变量选择 / 特征重要性。

explain(foreground_series=None, foreground_past_covariates=None, foreground_future_covariates=None, horizons=None, target_components=None)[源代码]

返回 foreground_series 中所有序列的 TFTExplainabilityResult 结果。如果 foreground_seriesNone,将使用 TFTExplainer 创建时的 background 输入(无论是创建时传递的 background,还是仅在单个序列上训练的 TFTModel 中存储的序列)。对于每个序列,结果包含注意力头、编码器变量重要性、解码器变量重要性和静态协变量重要性。

参数
  • foreground_series (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,一个或一系列目标 TimeSeries 需要解释。可以是多元的。如果没有提供,将解释背景 TimeSeries

  • foreground_past_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,如果预测模型需要,可以是一个或一系列过去协变量 TimeSeries

  • foreground_future_covariates (Union[TimeSeries, Sequence[TimeSeries], None]) – 可选地,如果预测模型需要,可以是一个或一系列未来协变量 TimeSeries

  • horizons (Optional[Sequence[int], None]) – 此参数不被 TFTExplainer 使用。

  • target_components (Optional[Sequence[str], None]) – 此参数不被 TFTExplainer 使用。

返回

包含注意力头、编码器变量重要性、解码器变量重要性和静态协变量重要性的可解释性结果。

返回类型

TFTExplainabilityResult

实际案例

>>> explainer = TFTExplainer(model)  # requires `background` if model was trained on multiple series

可选地,提供一个前景输入以在新输入上生成解释。否则,留空以从 TFTExplainer 创建中计算背景解释。

>>> explain_results = explainer.explain(
>>>     foreground_series=foreground_series,
>>>     foreground_past_covariates=foreground_past_covariates,
>>>     foreground_future_covariates=foreground_future_covariates,
>>> )
>>> attn = explain_results.get_attention()
>>> importances = explain_results.get_feature_importances()
model: TFTModel
plot_attention(expl_result, plot_type='all', show_index_as='relative', ax=None, max_nr_series=5, show_plot=True)[源代码]

绘制 TFTModel 的注意力头。

参数
  • expl_result (TFTExplainabilityResult) – 一个 TFTExplainabilityResult 对象。对应于 explain() 的输出。

  • plot_type (Optional[Literal[‘all’, ‘time’, ‘heatmap’], None]) – 注意力头图的类型。可以是以下之一:(“all”, “time”, “heatmap”)。如果选择 “all”,将绘制每个时间步的注意力(根据 TFTExplainabilityResult 中的时间步)。最大时间步对应于训练的 TFTModeloutput_chunk_length。如果选择 “time”,将绘制所有时间步的平均注意力。如果选择 “heatmap”,将在热图中绘制每个时间步的注意力。时间步显示在 y 轴上,时间/相对索引显示在 x 轴上。

  • show_index_as (Literal[‘relative’, ‘time’]) – 要显示的索引类型。可以是以下之一:(“relative”, “time”)。如果是 “relative”,将绘制 x 轴从 (-input_chunk_length, output_chunk_length - 1)0 对应于第一个预测点。如果是 “time”,将绘制 x 轴与相应的 TFTExplainabilityResult 的实际时间索引(或范围索引)。

  • ax (Optional[Axes, None]) – 可选地,绘图的轴。仅对单个 expl_result 有效。

  • max_nr_series (int) – 在 expl_result 计算了多个系列的情况下,显示的最大图表数量。

  • show_plot (bool) – 是否显示图表。

返回类型

Axes

plot_variable_selection(expl_result, fig_size=None, max_nr_series=5)[源代码]

根据输入绘制 TFTModel 的变量选择/特征重要性。该图包含三个子图:

  • 编码器重要性:包含过去目标、过去协变量和历史未来协变量对编码器(输入块)的重要性

  • 解码器重要性: 包含未来协变量对解码器(输出块)的重要性

  • 静态协变量重要性:包含数值和/或分类静态协变量的重要性

参数
  • expl_result (TFTExplainabilityResult) – 一个 TFTExplainabilityResult 对象。对应于 explain() 的输出。

  • fig_size – 要绘制的图形的大小。

  • max_nr_series (int) – 在 expl_result 计算了多个系列的情况下,显示的最大图表数量。