创建消息传递网络

将卷积算子推广到不规则域通常表示为邻域聚合消息传递方案。 使用\(\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F\)表示第\((k-1)\)层中节点\(i\)的节点特征,以及\(\mathbf{e}_{j,i} \in \mathbb{R}^D\)表示从节点\(j\)到节点\(i\)的(可选的)边特征,消息传递图神经网络可以描述为

\[\mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),\]

其中 \(\bigoplus\) 表示一个可微的、排列不变的函数,例如,求和、平均值或最大值,而 \(\gamma\)\(\phi\) 表示可微函数,例如 MLPs(多层感知器)。

“MessagePassing” 基类

提供了 MessagePassing 基类,它通过自动处理消息传播来帮助创建此类消息传递图神经网络。 用户只需定义函数 \(\phi\) ,即 message(),和 \(\gamma\) ,即 update(),以及要使用的聚合方案,即 aggr="add"aggr="mean"aggr="max"

这是通过以下方法完成的:

  • MessagePassing(aggr="add", flow="source_to_target", node_dim=-2): 定义要使用的聚合方案("add", "mean""max")以及消息传递的流向("source_to_target""target_to_source")。 此外,node_dim 属性指示沿哪个轴传播。

  • MessagePassing.propagate(edge_index, size=None, **kwargs): 初始调用以开始传播消息。 接收边索引和所有需要的数据,这些数据用于构造消息并更新节点嵌入。 请注意,propagate() 不仅限于在形状为 [N, N] 的方形邻接矩阵中交换消息,还可以在一般的稀疏分配矩阵中交换消息,例如,形状为 [N, M] 的二部图,通过传递 size=(N, M) 作为附加参数。 如果设置为 None,则假定分配矩阵为方形矩阵。 对于具有两组独立节点和索引的二部图,每组节点持有自己的信息,可以通过将信息作为元组传递来标记这种分割,例如 x=(x_N, x_M)

  • MessagePassing.message(...): 为节点 \(i\) 构造消息,类似于 \(\phi\),针对每条边 \((j,i) \in \mathcal{E}\)(如果 flow="source_to_target")或 \((i,j) \in \mathcal{E}\)(如果 flow="target_to_source")。 可以接受最初传递给 propagate() 的任何参数。 此外,传递给 propagate() 的张量可以通过在变量名后附加 _i_j 来映射到相应的节点 \(i\)\(j\)例如 x_ix_j。 请注意,我们通常将 \(i\) 称为聚合信息的中心节点,将 \(j\) 称为邻居节点,因为这是最常见的表示法。

  • MessagePassing.update(aggr_out, ...): 更新节点嵌入,类似于每个节点 \(i \in \mathcal{V}\)\(\gamma\)。 将聚合的输出作为第一个参数,并将最初传递给 propagate() 的任何参数作为其他参数。

让我们通过重新实现两种流行的GNN变体来验证这一点,即Kipf和Welling的GCN层Wang等人的EdgeConv层

实现GCN层

GCN层在数学上定义为

\[\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},\]

其中相邻节点的特征首先通过权重矩阵 \(\mathbf{W}\) 进行转换,然后通过它们的度数进行归一化,最后进行求和。 最后,我们将偏置向量 \(\mathbf{b}\) 应用于聚合输出。 这个公式可以分为以下步骤:

  1. 向邻接矩阵添加自环。

  2. 线性变换节点特征矩阵。

  3. 计算归一化系数。

  4. \(\phi\)中标准化节点特征。

  5. 汇总相邻节点的特征("add" 聚合)。

  6. 应用最终的偏置向量。

步骤1-3通常在消息传递发生之前计算。 步骤4-5可以使用MessagePassing基类轻松处理。 完整的层实现如下所示:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out = out + self.bias

        return out

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

GCNConv 继承自 MessagePassing,并使用 "add" 传播。 该层的所有逻辑都在其 forward() 方法中实现。 在这里,我们首先使用 torch_geometric.utils.add_self_loops() 函数为边索引添加自环(步骤1),并通过调用 torch.nn.Linear 实例线性变换节点特征(步骤2)。

归一化系数由每个节点 \(i\) 的节点度 \(\deg(i)\) 得出,对于每条边 \((j,i) \in \mathcal{E}\),它被转换为 \(1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})\)。 结果保存在形状为 [num_edges, ] 的张量 norm 中(步骤3)。

然后我们调用propagate(),它在内部调用message()aggregate()update()。 我们将节点嵌入x和归一化系数norm作为消息传播的额外参数传递。

message()函数中,我们需要通过norm来归一化相邻节点的特征x_j。 这里,x_j表示一个提升的张量,它包含每条边的源节点特征,每个节点的邻居。 通过将_i_j附加到变量名,可以自动提升节点特征。 事实上,任何张量都可以通过这种方式转换,只要它们包含源节点或目标节点的特征。

这就是创建一个简单的消息传递层所需的全部内容。 你可以将这个层作为深度架构的构建块。 初始化和调用它是直接的:

conv = GCNConv(16, 32)
x = conv(x, edge_index)

实现边缘卷积

边缘卷积层处理图或点云,其数学定义为

\[\mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),\]

其中 \(h_{\mathbf{\Theta}}\) 表示一个MLP。 与GCN层类似,我们可以使用 MessagePassing 类来实现这一层,这次使用 "max" 聚合:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)

message() 函数内部,我们使用 self.mlp 来转换目标节点特征 x_i 和相对源节点特征 x_j - x_i 对于每条边 \((j,i) \in \mathcal{E}\)

边缘卷积实际上是一种动态卷积,它使用特征空间中的最近邻为每一层重新计算图。 幸运的是, 提供了一个名为 torch_geometric.nn.pool.knn_graph() 的GPU加速的批量k-NN图生成方法:

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)

在这里,knn_graph() 计算了一个最近邻图,该图进一步用于调用 EdgeConvforward() 方法。

这为我们提供了一个干净的接口来初始化和调用这一层:

conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)

Exercises

假设我们得到了以下Data对象:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous())

尝试回答以下与GCNConv相关的问题:

  1. rowcol 包含什么信息?

  2. degree() 是做什么的?

  3. 为什么我们使用degree(col, ...)而不是degree(row, ...)

  4. deg_inv_sqrt[col]deg_inv_sqrt[row] 是做什么的?

  5. x_jmessage() 函数中包含了什么信息?如果 self.lin 表示恒等函数,那么 x_j 的确切内容是什么?

  6. 添加一个update()函数到GCNConv,该函数将转换后的中心节点特征添加到聚合输出中。

尝试回答以下与EdgeConv相关的问题:

  1. 什么是 x_ix_j - x_i

  2. torch.cat([x_i, x_j - x_i], dim=1) 是做什么的?为什么 dim = 1