torch_geometric.nn.models.GraphUNet
- class GraphUNet(in_channels: int, hidden_channels: int, out_channels: int, depth: int, pool_ratios: Union[float, List[float]] = 0.5, sum_res: bool = True, act: Union[str, Callable] = 'relu')[source]
Bases:
Module来自“Graph U-Nets”论文的Graph U-Net模型,该模型实现了类似U-Net的架构,包含图池化和反池化操作。
- Parameters:
in_channels (int) – Size of each input sample.
hidden_channels (int) – Size of each hidden sample.
out_channels (int) – Size of each output sample.
depth (int) – U-Net架构的深度。
sum_res (bool, 可选) – 如果设置为
False,将使用连接而不是求和来集成跳跃连接。(默认值:True)act (torch.nn.functional, 可选) – 要使用的非线性函数。 (默认:
torch.nn.functional.relu)