TFT 解释器 用于 Temporal Fusion Transformer (TFTModel)¶
TFTExplainer 使用一个训练好的 TFTModel
并从模型中提取解释性信息。
plot_variable_selection()
绘制每个输入特征的变量选择权重。 - 编码器重要性:目标的历史部分、过去的协变量和未来协变量的历史部分 - 解码器重要性:未来协变量的未来部分 - 静态协变量重要性:数值和分类静态协变量的重要性plot_attention()
绘制了 TFTModel 对给定的过去和未来输入应用的转换器注意力。注意力在所有注意力头之间进行了聚合。
可以使用 explain()
返回的 TFTExplainabilityResult
提取注意力和特征重要性值。方法描述中展示了一个示例。
我们还展示了如何在 TFTModel 的示例笔记本中使用 TFTExplainer,请参见 这里。
- class darts.explainability.tft_explainer.TFTExplainer(model, background_series=None, background_past_covariates=None, background_future_covariates=None)[源代码]¶
基类:
_ForecastingModelExplainer
TFTModel 的解释器类。
定义
背景系列是一个用于生成可解释性结果的 TimeSeries ,如果未向
explain()
传递 foreground ,则使用该默认值。前景系列是一个可以传递给
explain()
的 TimeSeries,用于生成可解释性结果,而不是使用背景。
- 参数
model (darts.models.forecasting.tft_model.TFTModel) – 要解释的拟合 TFTModel。
background_series (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,一系列或一系列列表,用作解释的默认目标系列。如果 model 是在单个目标系列上训练的,则此项为可选。默认情况下,它是拟合时使用的 series。如果 model 是在多个(序列的)目标系列上训练的,则此项为必填。background_past_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,一个过去的协变量序列或序列列表,用作解释的默认过去协变量序列。与 background_series 相同的要求适用。background_future_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,一个未来协变量序列或作为解释默认未来协变量序列的序列列表。与 background_series 相同的要求适用。
实际案例
>>> from darts.datasets import AirPassengersDataset >>> from darts.explainability.tft_explainer import TFTExplainer >>> from darts.models import TFTModel >>> series = AirPassengersDataset().load() >>> model = TFTModel( >>> input_chunk_length=12, >>> output_chunk_length=6, >>> add_encoders={"cyclic": {"future": ["hour"]}} >>> ) >>> model.fit(series) >>> # create the explainer and generate explanations >>> explainer = TFTExplainer(model) >>> results = explainer.explain() >>> # plot the results >>> explainer.plot_attention(results, plot_type="all") >>> explainer.plot_variable_selection(results)
方法
explain
([foreground_series, ...])返回 foreground_series 中所有序列的
TFTExplainabilityResult
结果。plot_attention
(expl_result[, plot_type, ...])绘制 TFTModel 的注意力头。
plot_variable_selection
(expl_result[, ...])根据输入绘制 TFTModel 的变量选择 / 特征重要性。
- explain(foreground_series=None, foreground_past_covariates=None, foreground_future_covariates=None, horizons=None, target_components=None)[源代码]¶
返回 foreground_series 中所有序列的
TFTExplainabilityResult
结果。如果 foreground_series 为 None,将使用 TFTExplainer 创建时的 background 输入(无论是创建时传递的 background,还是仅在单个序列上训练的 TFTModel 中存储的序列)。对于每个序列,结果包含注意力头、编码器变量重要性、解码器变量重要性和静态协变量重要性。- 参数
foreground_series (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,一个或一系列目标 TimeSeries 需要解释。可以是多元的。如果没有提供,将解释背景 TimeSeries 。foreground_past_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,如果预测模型需要,可以是一个或一系列过去协变量 TimeSeries。foreground_future_covariates (
Union
[TimeSeries
,Sequence
[TimeSeries
],None
]) – 可选地,如果预测模型需要,可以是一个或一系列未来协变量 TimeSeries。horizons (
Optional
[Sequence
[int
],None
]) – 此参数不被 TFTExplainer 使用。target_components (
Optional
[Sequence
[str
],None
]) – 此参数不被 TFTExplainer 使用。
- 返回
包含注意力头、编码器变量重要性、解码器变量重要性和静态协变量重要性的可解释性结果。
- 返回类型
实际案例
>>> explainer = TFTExplainer(model) # requires `background` if model was trained on multiple series
可选地,提供一个前景输入以在新输入上生成解释。否则,留空以从 TFTExplainer 创建中计算背景解释。
>>> explain_results = explainer.explain( >>> foreground_series=foreground_series, >>> foreground_past_covariates=foreground_past_covariates, >>> foreground_future_covariates=foreground_future_covariates, >>> ) >>> attn = explain_results.get_attention() >>> importances = explain_results.get_feature_importances()
- plot_attention(expl_result, plot_type='all', show_index_as='relative', ax=None, max_nr_series=5, show_plot=True)[源代码]¶
绘制 TFTModel 的注意力头。
- 参数
expl_result (
TFTExplainabilityResult
) – 一个 TFTExplainabilityResult 对象。对应于explain()
的输出。plot_type (
Optional
[Literal
[‘all’, ‘time’, ‘heatmap’],None
]) – 注意力头图的类型。可以是以下之一:(“all”, “time”, “heatmap”)。如果选择 “all”,将绘制每个时间步的注意力(根据 TFTExplainabilityResult 中的时间步)。最大时间步对应于训练的 TFTModel 的 output_chunk_length。如果选择 “time”,将绘制所有时间步的平均注意力。如果选择 “heatmap”,将在热图中绘制每个时间步的注意力。时间步显示在 y 轴上,时间/相对索引显示在 x 轴上。show_index_as (
Literal
[‘relative’, ‘time’]) – 要显示的索引类型。可以是以下之一:(“relative”, “time”)。如果是 “relative”,将绘制 x 轴从 (-input_chunk_length, output_chunk_length - 1)。0 对应于第一个预测点。如果是 “time”,将绘制 x 轴与相应的 TFTExplainabilityResult 的实际时间索引(或范围索引)。ax (
Optional
[Axes
,None
]) – 可选地,绘图的轴。仅对单个 expl_result 有效。max_nr_series (
int
) – 在 expl_result 计算了多个系列的情况下,显示的最大图表数量。show_plot (
bool
) – 是否显示图表。
- 返回类型
Axes
- plot_variable_selection(expl_result, fig_size=None, max_nr_series=5)[源代码]¶
根据输入绘制 TFTModel 的变量选择/特征重要性。该图包含三个子图:
编码器重要性:包含过去目标、过去协变量和历史未来协变量对编码器(输入块)的重要性
解码器重要性: 包含未来协变量对解码器(输出块)的重要性
静态协变量重要性:包含数值和/或分类静态协变量的重要性
- 参数
expl_result (
TFTExplainabilityResult
) – 一个 TFTExplainabilityResult 对象。对应于explain()
的输出。fig_size – 要绘制的图形的大小。
max_nr_series (
int
) – 在 expl_result 计算了多个系列的情况下,显示的最大图表数量。