torch_geometric.nn.models.GraphUNet

class GraphUNet(in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Union[str, Callable] = 'relu')[source]

Bases: Module

来自“Graph U-Nets”论文的Graph U-Net模型,该模型实现了类似U-Net的架构,包含图池化和反池化操作。

Parameters:
  • in_channels (int) – Size of each input sample.

  • hidden_channels (int) – Size of each hidden sample.

  • out_channels (int) – Size of each output sample.

  • depth (int) – U-Net架构的深度。

  • pool_ratios (float[float], 可选) – 每个深度的图池化比例。(默认值:0.5

  • sum_res (bool, 可选) – 如果设置为 False,将使用连接而不是求和来集成跳跃连接。(默认值:True

  • act (torch.nn.functional, 可选) – 要使用的非线性函数。 (默认: torch.nn.functional.relu)

forward(x: Tensor, edge_index: Tensor, batch: Optional[Tensor] = None, edge_weight: Optional[Tensor] = None) Tensor[source]
Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。