Author: akensert
Date created: 2021/09/13
Last modified: 2021/12/26
Description: An implementation of a Graph Attention Network (GAT) for node classification.
Graph neural networks are the preferred neural network architecture for processing data structured as graphs (for example, social networks or molecular structures), yielding better results than fully-connected networks or convolutional networks.
In this tutorial, we will implement a specific graph neural network known as a Graph Attention Network (GAT) to predict labels of scientific papers based on what types of papers cite them (using the Cora dataset).
For more information on GAT, see the original paper Graph Attention Networks as well as DGL's Graph Attention Networks documentation.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings
warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)
The preparation of the Cora dataset follows that of the Node classification with Graph Neural Networks tutorial. Refer to that tutorial for more details on the dataset and on exploratory data analysis. In brief, the Cora dataset consists of two files: cora.cites, which contains directed links (citations) between papers, and cora.content, which contains features of the corresponding papers and one of seven labels (the subject of the paper).
zip_file = keras.utils.get_file(
fname="cora.tgz",
origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
extract=True,
)
data_dir = os.path.join(os.path.dirname(zip_file), "cora")
citations = pd.read_csv(
os.path.join(data_dir, "cora.cites"),
sep="\t",
header=None,
names=["target", "source"],
)
papers = pd.read_csv(
os.path.join(data_dir, "cora.content"),
sep="\t",
header=None,
names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)
class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])
print(citations)
print(papers)
target source
0 0 21
1 0 905
2 0 906
... ... ...
5426 1874 2586
5427 1876 1874
5428 1897 2707
[5429 rows x 2 columns]
paper_id term_0 term_1 ... term_1431 term_1432 subject
0 462 0 0 ... 0 0 2
1 1911 0 0 ... 0 0 5
2 2002 0 0 ... 0 0 4
... ... ... ... ... ... ... ...
2705 2372 0 0 ... 0 0 1
2706 955 0 0 ... 0 0 0
2707 376 0 0 ... 0 0 2
[2708 rows x 1435 columns]
# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))
# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]
# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()
# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()
# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])
# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)
Edges shape:         (5429, 2)
Node features shape: (2708, 1433)
GAT takes as input a graph (namely an edge tensor and a node feature tensor) and outputs [updated] node states. For each target node, the node states are the aggregated information of N-hop neighborhoods (where N is decided by the number of layers of the GAT). Importantly, in contrast to the graph convolutional network (GCN), the GAT makes use of attention mechanisms to aggregate information from neighboring nodes (or source nodes). In other words, instead of simply averaging/summing node states from source nodes (source papers) to the target node (target papers), GAT first applies normalized attention scores to each source node state and then sums.
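To make the difference concrete, here is a minimal sketch (not part of the original example; the tiny neighbor states and scores are invented for illustration) contrasting GCN-style averaging with GAT-style attention-weighted aggregation for a single target node:

# Hypothetical states of three 1-hop neighbors (source nodes), 2 features each
neighbor_states = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# Hypothetical unnormalized attention scores for the three incoming edges
raw_scores = tf.constant([0.5, 2.0, -1.0])

# GCN-style update: every neighbor contributes equally (weight 1/3)
mean_update = tf.reduce_mean(neighbor_states, axis=0)

# GAT-style update: softmax-normalize the scores so they sum to 1,
# then take a weighted sum of the neighbor states
weights = tf.nn.softmax(raw_scores)
gat_update = tf.reduce_sum(weights[:, None] * neighbor_states, axis=0)

print(mean_update.numpy())  # [0.667 0.667]
print(gat_update.numpy())   # dominated by the neighbor with score 2.0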
The GAT model implements multi-head graph attention layers. The MultiHeadGraphAttention layer is simply a concatenation (or averaging) of multiple graph attention layers (GraphAttention), each with separate learnable weights W. The GraphAttention layer does the following:

Consider inputs node states h^{l}, which are linearly transformed by W^{l}, resulting in z^{l}.

For each target node:

1. Computes pair-wise attention scores a^{l T}(z^{l}_{i} || z^{l}_{j}) for all j, resulting in e_{ij}. Here, || denotes concatenation, _{i} corresponds to the target node, and _{j} corresponds to a given 1-hop neighbor/source node.
2. Normalizes e_{ij} via softmax, so that the sum of the incoming edges' attention scores to the target node (sum_{k} e_{norm,ik}) equals 1.
3. Applies the attention scores e_{norm,ij} to z_{j} and adds the result to the new target node state h^{l+1}_{i}, for all j.

These three steps are restated as equations below.
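With N(i) denoting the set of 1-hop neighbors (source nodes) of target node i, and writing out the LeakyReLU that both the original paper and the code below apply to the raw scores:

e_{ij} = LeakyReLU(a^{l T}(z^{l}_{i} || z^{l}_{j})), for all j in N(i)
e_{norm,ij} = exp(e_{ij}) / sum_{k in N(i)} exp(e_{ik})
h^{l+1}_{i} = sum_{j in N(i)} e_{norm,ij} * z^{l}_{j}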
class GraphAttention(layers.Layer):
def __init__(
self,
units,
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
**kwargs,
):
super().__init__(**kwargs)
self.units = units
self.kernel_initializer = keras.initializers.get(kernel_initializer)
self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[0][-1], self.units),
trainable=True,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
name="kernel",
)
self.kernel_attention = self.add_weight(
shape=(self.units * 2, 1),
trainable=True,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
name="kernel_attention",
)
self.built = True
def call(self, inputs):
node_states, edges = inputs
        # Linearly transform node states
node_states_transformed = tf.matmul(node_states, self.kernel)
        # (1) Compute pair-wise attention scores
node_states_expanded = tf.gather(node_states_transformed, edges)
node_states_expanded = tf.reshape(
node_states_expanded, (tf.shape(edges)[0], -1)
)
attention_scores = tf.nn.leaky_relu(
tf.matmul(node_states_expanded, self.kernel_attention)
)
attention_scores = tf.squeeze(attention_scores, -1)
        # (2) Normalize attention scores
attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
attention_scores_sum = tf.math.unsorted_segment_sum(
data=attention_scores,
segment_ids=edges[:, 0],
num_segments=tf.reduce_max(edges[:, 0]) + 1,
)
attention_scores_sum = tf.repeat(
attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
)
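        # Together, the exp / segment-sum / repeat sequence above implements
        # a softmax over each target node's incoming edges (the scores are
        # clipped to [-2, 2] before exponentiation, which keeps the values
        # bounded). Note that tf.repeat aligns the summed scores with the
        # edge list only because the edges are sorted by target node, as
        # they are in cora.cites.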
attention_scores_norm = attention_scores / attention_scores_sum
        # (3) Gather node states of neighbors, apply attention scores and aggregate
node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
out = tf.math.unsorted_segment_sum(
data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
segment_ids=edges[:, 0],
num_segments=tf.shape(node_states)[0],
)
return out
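As a quick shape check, the layer can be exercised on a tiny made-up graph (a sketch only; toy_states and toy_edges are hypothetical and not part of the Cora pipeline). Note that the edge list is sorted by target node, which the repeat-based normalization above relies on:

toy_states = tf.random.normal((3, 4))  # 3 nodes, 4 features each
toy_edges = tf.constant(
    [[0, 1], [0, 2], [1, 2], [2, 0]], dtype=tf.int64
)  # (target, source) pairs, sorted by target

toy_layer = GraphAttention(units=8)
print(toy_layer([toy_states, toy_edges]).shape)  # (3, 8): one state per node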
class MultiHeadGraphAttention(layers.Layer):
def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.merge_type = merge_type
self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]
def call(self, inputs):
atom_features, pair_indices = inputs
        # Obtain outputs from each attention head
outputs = [
attention_layer([atom_features, pair_indices])
for attention_layer in self.attention_layers
]
        # Concatenate or average the node states from each head
if self.merge_type == "concat":
outputs = tf.concat(outputs, axis=-1)
else:
outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
return tf.nn.relu(outputs)
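The merge_type argument controls the width of the output: "concat" yields units * num_heads features per node, while any other value averages the heads and keeps units features. A quick sketch, reusing the hypothetical toy_states and toy_edges from above:

concat_heads = MultiHeadGraphAttention(units=8, num_heads=4, merge_type="concat")
mean_heads = MultiHeadGraphAttention(units=8, num_heads=4, merge_type="average")

print(concat_heads([toy_states, toy_edges]).shape)  # (3, 32)
print(mean_heads([toy_states, toy_edges]).shape)  # (3, 8)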
Next, we implement the training logic with custom train_step, test_step, and predict_step methods. Notice that the GAT model operates on the entire graph (namely, node_states and edges) in all phases (training, validation and testing). Hence, node_states and edges are passed to the constructor of the keras.Model and used as attributes. The difference between the phases lies in the indices (and labels), which gather certain outputs (tf.gather(outputs, indices)); a minimal illustration follows below.
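A sketch of the gathering step (the numbers here are invented for illustration): the model always produces one output row per node in the graph, and tf.gather selects just the rows belonging to the current batch of paper indices.

all_outputs = tf.constant([[0.1], [0.2], [0.3], [0.4]])  # one row per node
batch_indices = tf.constant([3, 1])  # paper indices in the current batch
print(tf.gather(all_outputs, batch_indices).numpy())  # selects rows 3 and 1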
class GraphAttentionNetwork(keras.Model):
def __init__(
self,
node_states,
edges,
hidden_units,
num_heads,
num_layers,
output_dim,
**kwargs,
):
super().__init__(**kwargs)
self.node_states = node_states
self.edges = edges
self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
self.attention_layers = [
MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
]
self.output_layer = layers.Dense(output_dim)
def call(self, inputs):
node_states, edges = inputs
x = self.preprocess(node_states)
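        # Each attention block adds a residual (skip) connection from its
        # input to its output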
for attention_layer in self.attention_layers:
x = attention_layer([x, edges]) + x
outputs = self.output_layer(x)
return outputs
def train_step(self, data):
indices, labels = data
with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
return {m.name: m.result() for m in self.metrics}
def predict_step(self, data):
indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))
def test_step(self, data):
indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
return {m.name: m.result() for m in self.metrics}
# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)
NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)
# Build model
gat_model = GraphAttentionNetwork(
node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)
# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])
gat_model.fit(
x=train_indices,
y=train_labels,
validation_split=VALIDATION_SPLIT,
batch_size=BATCH_SIZE,
epochs=NUM_EPOCHS,
callbacks=[early_stopping],
verbose=2,
)
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)
print("--" * 38 + f"\n测试准确率 {test_accuracy*100:.1f}%")
Epoch 1/100
5/5 - 26s - loss: 1.8418 - acc: 0.2980 - val_loss: 1.5117 - val_acc: 0.4044 - 26s/epoch - 5s/step
Epoch 2/100
5/5 - 6s - loss: 1.2422 - acc: 0.5640 - val_loss: 1.0407 - val_acc: 0.6471 - 6s/epoch - 1s/step
Epoch 3/100
5/5 - 5s - loss: 0.7092 - acc: 0.7906 - val_loss: 0.8201 - val_acc: 0.7868 - 5s/epoch - 996ms/step
Epoch 4/100
5/5 - 5s - loss: 0.4768 - acc: 0.8604 - val_loss: 0.7451 - val_acc: 0.8088 - 5s/epoch - 934ms/step
Epoch 5/100
5/5 - 5s - loss: 0.2641 - acc: 0.9294 - val_loss: 0.7499 - val_acc: 0.8088 - 5s/epoch - 945ms/step
Epoch 6/100
5/5 - 5s - loss: 0.1487 - acc: 0.9663 - val_loss: 0.6803 - val_acc: 0.8382 - 5s/epoch - 967ms/step
Epoch 7/100
5/5 - 5s - loss: 0.0970 - acc: 0.9811 - val_loss: 0.6688 - val_acc: 0.8088 - 5s/epoch - 960ms/step
Epoch 8/100
5/5 - 5s - loss: 0.0597 - acc: 0.9934 - val_loss: 0.7295 - val_acc: 0.8162 - 5s/epoch - 981ms/step
Epoch 9/100
5/5 - 5s - loss: 0.0398 - acc: 0.9967 - val_loss: 0.7551 - val_acc: 0.8309 - 5s/epoch - 991ms/step
Epoch 10/100
5/5 - 5s - loss: 0.0312 - acc: 0.9984 - val_loss: 0.7666 - val_acc: 0.8309 - 5s/epoch - 987ms/step
Epoch 11/100
5/5 - 5s - loss: 0.0219 - acc: 0.9992 - val_loss: 0.7726 - val_acc: 0.8309 - 5s/epoch - 1s/step
----------------------------------------------------------------------------
Test Accuracy 76.5%
test_probs = gat_model.predict(x=test_indices)
mapping = {v: k for (k, v) in class_idx.items()}
for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
print(f"示例 {i+1}: {mapping[label]}")
for j, c in zip(probs, class_idx.keys()):
print(f"\t{c: <24} 的概率 = {j*100:7.3f}%")
print("---" * 20)
Example 1: Probabilistic_Methods
        Probability of Case_Based               =   0.919%
        Probability of Genetic_Algorithms       =   0.180%
        Probability of Neural_Networks          =  37.896%
        Probability of Probabilistic_Methods    =  59.801%
        Probability of Reinforcement_Learning   =   0.705%
        Probability of Rule_Learning            =   0.044%
        Probability of Theory                   =   0.454%
------------------------------------------------------------
Example 2: Genetic_Algorithms
        Probability of Case_Based               =   0.005%
        Probability of Genetic_Algorithms       =  99.993%
        Probability of Neural_Networks          =   0.001%
        Probability of Probabilistic_Methods    =   0.000%
        Probability of Reinforcement_Learning   =   0.000%
        Probability of Rule_Learning            =   0.000%
        Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 3: Theory
        Probability of Case_Based               =   8.151%
        Probability of Genetic_Algorithms       =   1.021%
        Probability of Neural_Networks          =   0.569%
        Probability of Probabilistic_Methods    =  40.220%
        Probability of Reinforcement_Learning   =   0.792%
        Probability of Rule_Learning            =   6.910%
        Probability of Theory                   =  42.337%
------------------------------------------------------------
Example 4: Neural_Networks
        Probability of Case_Based               =   0.097%
        Probability of Genetic_Algorithms       =   0.026%
        Probability of Neural_Networks          =  93.539%
        Probability of Probabilistic_Methods    =   6.206%
        Probability of Reinforcement_Learning   =   0.028%
        Probability of Rule_Learning            =   0.010%
        Probability of Theory                   =   0.094%
------------------------------------------------------------
Example 5: Theory
        Probability of Case_Based               =  25.259%
        Probability of Genetic_Algorithms       =   4.381%
        Probability of Neural_Networks          =  11.776%
        Probability of Probabilistic_Methods    =  15.053%
        Probability of Reinforcement_Learning   =   1.571%
        Probability of Rule_Learning            =  23.589%
        Probability of Theory                   =  18.370%
------------------------------------------------------------
Example 6: Genetic_Algorithms
        Probability of Case_Based               =   0.000%
        Probability of Genetic_Algorithms       = 100.000%
        Probability of Neural_Networks          =   0.000%
        Probability of Probabilistic_Methods    =   0.000%
        Probability of Reinforcement_Learning   =   0.000%
        Probability of Rule_Learning            =   0.000%
        Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 7: Neural_Networks
        Probability of Case_Based               =   0.296%
        Probability of Genetic_Algorithms       =   0.291%
        Probability of Neural_Networks          =  93.419%
        Probability of Probabilistic_Methods    =   5.696%
        Probability of Reinforcement_Learning   =   0.050%
        Probability of Rule_Learning            =   0.072%
        Probability of Theory                   =   0.177%
------------------------------------------------------------
Example 8: Genetic_Algorithms
        Probability of Case_Based               =   0.000%
        Probability of Genetic_Algorithms       = 100.000%
        Probability of Neural_Networks          =   0.000%
        Probability of Probabilistic_Methods    =   0.000%
        Probability of Reinforcement_Learning   =   0.000%
        Probability of Rule_Learning            =   0.000%
        Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 9: Theory
        Probability of Case_Based               =   4.103%
        Probability of Genetic_Algorithms       =   5.217%
        Probability of Neural_Networks          =  14.532%
        Probability of Probabilistic_Methods    =  66.747%
        Probability of Reinforcement_Learning   =   3.008%
        Probability of Rule_Learning            =   1.782%
        Probability of Theory                   =   4.611%
------------------------------------------------------------
Example 10: Case_Based
        Probability of Case_Based               =  99.566%
        Probability of Genetic_Algorithms       =   0.017%
        Probability of Neural_Networks          =   0.016%
        Probability of Probabilistic_Methods    =   0.155%
        Probability of Reinforcement_Learning   =   0.026%
        Probability of Rule_Learning            =   0.192%
        Probability of Theory                   =   0.028%
------------------------------------------------------------
The results look OK! The GAT model seems to correctly predict the subjects of the papers, based on what they cite, about 80% of the time. Further improvements could be made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers, the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout); or modify the preprocessing step. We could also try to implement self-loops (i.e., paper X cites paper X) and/or make the graph undirected, as sketched below.
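As a hedged sketch of that last suggestion (untested here, and not part of the original example; variable names follow the tutorial), the edge tensor could be augmented with reversed edges and self-loops before building the model:

# Add reversed edges (making the graph effectively undirected) and
# self-loops (each paper "cites" itself)
num_nodes = node_states.shape[0]
self_loops = tf.stack([tf.range(num_nodes, dtype=edges.dtype)] * 2, axis=1)
reversed_edges = tf.reverse(edges, axis=[1])
augmented_edges = tf.concat([edges, reversed_edges, self_loops], axis=0)

# The repeat-based normalization in GraphAttention assumes the edge list
# is sorted by target node, so re-sort after augmenting
order = tf.argsort(augmented_edges[:, 0])
augmented_edges = tf.gather(augmented_edges, order)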