torch_geometric.explain.algorithm.PGExplainer

class PGExplainer(epochs: int, lr: float = 0.003, **kwargs)[source]

Bases: ExplainerAlgorithm

PGExplainer模型来自“Parameterized Explainer for Graph Neural Network”论文。

在内部,它利用神经网络来识别在GNN预测中起关键作用的子图结构。重要的是,PGExplainer需要通过train()进行训练,然后才能生成解释:

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=ModelConfig(...),
)

# Train against a variety of node-level or graph-level predictions:
for epoch in range(30):
    for index in [...]:  # Indices to train against.
        loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                         target=target, index=index)

# Get the final explanations:
explanation = explainer(x, edge_index, target=target, index=0)
Parameters:
  • epochs (int) – 训练的轮数。

  • lr (float, optional) – 应用的学习率。 (默认: 0.003).

  • **kwargs (可选) – 用于覆盖默认设置的其他超参数 在 coeffs中。

reset_parameters()[source]

重置模块的所有可学习参数。

train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)[source]

训练底层解释器模型。 在能够进行预测之前需要调用。

Parameters:
  • epoch (int) – 训练阶段的当前周期。

  • model (torch.nn.Module) – The model to explain.

  • x (torch.Tensor) – 同构图中的输入节点特征。

  • edge_index (torch.Tensor) – 同质图的输入边索引。

  • target (torch.Tensor) – The target of the model.

  • index (inttorch.Tensor, 可选) – 要解释的模型输出的索引。需要是单个索引。 (默认: None)

  • **kwargs (optional) – Additional keyword arguments passed to model.

forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) Explanation[source]

计算解释。

Parameters:
  • model (torch.nn.Module) – The model to explain.

  • x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.

  • edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.

  • target (torch.Tensor) – The target of the model.

  • index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)

  • **kwargs (optional) – Additional keyword arguments passed to model.

Return type:

Explanation

supports() bool[source]

Checks if the explainer supports the user-defined settings provided in self.explainer_config, self.model_config.

Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。