torch_geometric.explain.algorithm.PGExplainer
- class PGExplainer(epochs: int, lr: float = 0.003, **kwargs)
Bases: ExplainerAlgorithm
The PGExplainer model from the “Parameterized Explainer for Graph Neural Network” paper.
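Internally, it uses a neural network to identify the subgraph structures that play a crucial role in the predictions made by a GNN. Importantly, the PGExplainer needs to be trained via train() before it is able to generate explanations:

    from torch_geometric.explain import Explainer, PGExplainer
    from torch_geometric.explain.config import ModelConfig

    # `model`, `x`, `edge_index` and `target` are assumed to be defined already.
    explainer = Explainer(
        model=model,
        algorithm=PGExplainer(epochs=30, lr=0.003),
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=ModelConfig(...),
    )

    # Train against a variety of node-level or graph-level predictions:
    for epoch in range(30):
        for index in [...]:  # Indices to train against.
            loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                             target=target, index=index)

    # Get the final explanations:
    explanation = explainer(x, edge_index, target=target, index=0)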
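The "neural network" mentioned above scores edges rather than nodes. In the paper's graph-level formulation, each edge (i, j) is scored by an MLP applied to the concatenated GNN embeddings of its two endpoints (for node-level tasks the paper also appends the embedding of the explained node). The pure-Python sketch below shows only that scoring step under those assumptions; `edge_logits`, the toy embeddings `z` and the weights `w1`/`w2` are hypothetical stand-ins for the trained MLP, and bias terms are omitted.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def relu(v: Sequence[float]) -> List[float]:
    """Element-wise ReLU."""
    return [max(0.0, x) for x in v]


def matvec(w: Matrix, v: Sequence[float]) -> List[float]:
    """Dense matrix-vector product; w has shape (rows, len(v))."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def edge_logits(
    z: Sequence[Sequence[float]],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    w1: Matrix,
    w2: Sequence[float],
) -> List[float]:
    """Score each edge (src, dst) with a 2-layer MLP applied to [z_src || z_dst].

    Hypothetical helper: w1 has shape (hidden, 2 * dim), w2 has shape (hidden,).
    """
    src, dst = edge_index
    logits = []
    for s, d in zip(src, dst):
        edge_repr = list(z[s]) + list(z[d])   # concatenate endpoint embeddings
        hidden = relu(matvec(w1, edge_repr))  # hidden layer of the MLP
        logits.append(sum(a * b for a, b in zip(w2, hidden)))  # one logit per edge
    return logits


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Toy example: 3 nodes with 2-dimensional embeddings and a directed 3-cycle.
z = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edge_index = ([0, 1, 2], [1, 2, 0])
w1 = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
w2 = [1.0, -1.0]

logits = edge_logits(z, edge_index, w1, w2)
edge_mask = [sigmoid(v) for v in logits]  # soft importance per edge in (0, 1)
print(logits)  # [0.0, -1.0, 1.0]
```

Because the same MLP is shared across all edges and graphs, it is trained once and can then explain new instances without per-instance optimization, which is why train() must run before explanations are requested.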
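- Parameters:
epochs (int) – The number of epochs to train.
lr (float, optional) – The learning rate to apply. (default: 0.003)
**kwargs (optional) – Additional hyper-parameters to override default settings in coeffs.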
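The `epochs` argument in the class signature does more than bound the training loop. As far as we can tell from PyG's implementation, the temperature used when sampling relaxed edge masks is annealed geometrically over `epoch / epochs`, and extra keyword arguments are merged into a dictionary of hyper-parameters. The plain-Python sketch below mirrors that logic; the default values in `DEFAULT_COEFFS` and the helpers `make_coeffs` and `get_temperature` are assumptions for illustration, not the library's API.

```python
from typing import Dict, List

# Assumed defaults, modeled on PyG's PGExplainer coefficients; verify against your version.
DEFAULT_COEFFS: Dict[str, object] = {
    'edge_size': 0.05,   # weight of the mask-size penalty
    'edge_ent': 1.0,     # weight of the mask-entropy penalty
    'temp': [5.0, 2.0],  # start and end temperature of the annealing schedule
    'bias': 0.01,        # keeps the uniform noise away from 0 and 1
}


def make_coeffs(**kwargs: object) -> Dict[str, object]:
    """Copy the defaults and let keyword arguments override individual entries."""
    coeffs = dict(DEFAULT_COEFFS)
    coeffs.update(kwargs)
    return coeffs


def get_temperature(epoch: int, epochs: int, temp: List[float]) -> float:
    """Anneal geometrically from temp[0] at epoch 0 towards temp[1] at epoch == epochs."""
    start, end = temp
    return start * (end / start) ** (epoch / epochs)


coeffs = make_coeffs(edge_size=0.1)  # hypothetical override of one coefficient
schedule = [get_temperature(e, 30, coeffs['temp']) for e in range(30)]
```

A high temperature early on keeps the sampled masks soft and exploratory; lowering it later pushes them towards near-binary values.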
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)
Trains the underlying explainer model. Needs to be called before the explainer is able to make predictions.
- Parameters:
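epoch (int) – The current epoch of the training phase.
model (torch.nn.Module) – The model to explain.
x (torch.Tensor) – The input node features of a homogeneous graph.
edge_index (torch.Tensor) – The input edge indices of a homogeneous graph.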
target (torch.Tensor) – The target of the model.
index (int or torch.Tensor, optional) – The index of the model output to explain. Needs to be a single index. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
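During train(), the MLP's edge logits are not used directly. As we understand the paper and PyG's implementation, each logit is perturbed with logistic noise and divided by the current temperature (a binary concrete, or relaxed Bernoulli, sample), and the loss adds a size penalty and an entropy penalty on the resulting soft mask to the prediction loss. The sketch below covers only the sampling and the two regularizers; the prediction term (for example, cross-entropy of the model output on the masked graph) is omitted, and the helper names and default weights `edge_size=0.05`, `edge_ent=1.0`, `bias=0.01` are assumptions.

```python
import math
import random
from typing import List, Optional


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def concrete_sample(logits: List[float], temperature: float, bias: float = 0.01,
                    rng: Optional[random.Random] = None) -> List[float]:
    """Binary-concrete sample per edge, returned as pre-sigmoid values."""
    if rng is None:
        rng = random.Random()
    samples = []
    for logit in logits:
        # Draw eps uniformly from [bias, 1 - bias) so both logarithms stay finite.
        eps = (1 - 2 * bias) * rng.random() + bias
        # log(eps) - log(1 - eps) is logistic noise added to the edge logit.
        samples.append((math.log(eps) - math.log(1 - eps) + logit) / temperature)
    return samples


def mask_regularizer(mask_logits: List[float], edge_size: float = 0.05,
                     edge_ent: float = 1.0) -> float:
    """Size penalty (sum of mask values) plus mean binary entropy of the mask."""
    mask = [sigmoid(v) for v in mask_logits]
    size_loss = edge_size * sum(mask)
    # Squash into [0.005, 0.995] so the logarithms below are finite.
    mask = [0.99 * m + 0.005 for m in mask]
    entropy = [-m * math.log(m) - (1 - m) * math.log(1 - m) for m in mask]
    return size_loss + edge_ent * sum(entropy) / len(entropy)


# One hypothetical training step for three edges at temperature 5.0:
rng = random.Random(0)
relaxed = concrete_sample([2.0, -1.0, 0.5], temperature=5.0, rng=rng)
penalty = mask_regularizer(relaxed)
```

The size term favors small explanatory subgraphs, while the entropy term pushes each mask value towards 0 or 1 so the explanation is close to a discrete edge selection.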
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Explanation
Computes the explanation.
- Parameters:
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
- Return type:
Explanation
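At explanation time no noise is involved. As we understand PyG's implementation, forward() runs the trained MLP once, applies a sigmoid to the edge logits and, for node-level tasks, sets the mask to zero for edges outside the computation subgraph of the explained node; the result is returned as the edge_mask of the Explanation. A plain-Python sketch of that post-processing step, using the hypothetical helper `post_process_mask`:

```python
import math
from typing import List, Optional


def post_process_mask(logits: List[float],
                      hard_mask: Optional[List[bool]] = None) -> List[float]:
    """Deterministic edge mask: sigmoid of each logit, zeroed outside hard_mask."""
    mask = [1.0 / (1.0 + math.exp(-v)) for v in logits]
    if hard_mask is not None:
        # Edges outside the explained computation subgraph get importance 0.
        mask = [m if keep else 0.0 for m, keep in zip(mask, hard_mask)]
    return mask


# Node-level example: the third edge lies outside the target node's k-hop subgraph.
edge_mask = post_process_mask([0.0, 3.0, 5.0], hard_mask=[True, True, False])
print(edge_mask[0], edge_mask[2])  # 0.5 0.0
```

Note that the third edge receives importance 0 despite having the largest logit, because it cannot influence the explained prediction.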