torch_geometric.nn.models.GLEM
- class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: torch_geometric.nn.models = GraphSAGE, out_channels: int = 47, gnn_loss=CrossEntropyLoss(), lm_loss=CrossEntropyLoss(), alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Union[str, torch.device] = device(type='cpu'))[source]
Bases: Module
This GNN+LM co-training model is based on GLEM from the "Learning on Large-scale Text-attributed Graphs via Variational Inference" paper.
- Parameters:
lm_to_use (str) – A TextEncoder with a classifier from the HuggingFace model hub (default: TinyBERT)
gnn_to_use (torch_geometric.nn.models) – The GNN model to use (default: GraphSAGE)
out_channels (int) – Number of output channels; must be the same for the LM and the GNN
num_gnn_heads (Optional[int]) – Number of attention heads, if needed
num_gnn_layers (int) – Number of GNN layers
gnn_loss – Loss function for the GNN (default: CrossEntropyLoss)
lm_loss – Loss function for the language model (default: CrossEntropyLoss)
alpha (float) – Pseudo-label weight for the E-step (LM optimization) (default: 0.5)
beta (float) – Pseudo-label weight for the M-step (GNN optimization) (default: 0.5)
lm_dtype (torch.dtype) – Data type the LM is loaded into memory with (default: torch.bfloat16)
lm_use_lora (bool) – Whether to fine-tune with LoRA PEFT (default: True)
lora_target_modules (Union[str, List[str], None]) – Names of the target modules to apply the LoRA adapter to; for an LLM this could be ['q_proj', 'v_proj'] (default: None)
Note
See examples/llm_plus_gnn/glem.py for an example of usage.
- forward(*input: Any) → None
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type:
None
- train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)[source]
GLEM training step for one EM phase.
- Parameters:
em_phase (str) – 'gnn' or 'lm', selecting which phase you are training
train_loader (Union[DataLoader, NeighborLoader]) – Use DataLoader for LM training, containing tokenized data, labels and the is_gold mask; use NeighborLoader for GNN training, containing x and edge_index
optimizer (torch.optim.Optimizer) – Optimizer for training
pseudo_labels (torch.Tensor) – Predicted labels used as pseudo-labels
epoch (int) – Current epoch
is_augmented (bool) – Whether to use pseudo-labels
verbose (bool) – Whether to print the training progress bar
- Returns:
acc (float): training accuracy
loss (float): loss value
- Return type:
acc (float)
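The EM alternation that this training step drives can be sketched in plain Python; the run_em helper and the trainer callables below are hypothetical stand-ins for illustration, not the torch_geometric API:

```python
# Minimal sketch of GLEM-style EM alternation with stand-in trainers.
# The phase names mirror the em_phase argument ('lm' plays the E-step,
# 'gnn' the M-step); everything else here is illustrative only.

def run_em(num_em_iters, epochs_per_phase, trainers):
    history = []
    for _ in range(num_em_iters):
        for phase in ('lm', 'gnn'):            # E-step, then M-step
            for epoch in range(epochs_per_phase):
                acc = trainers[phase](epoch)   # would call model.train(...)
                history.append((phase, epoch, acc))
            # Here the freshly trained model would produce pseudo-labels
            # for the other phase via the inference step (omitted).
    return history

# Stand-in trainers that just return a dummy accuracy:
trainers = {'lm': lambda epoch: 0.5, 'gnn': lambda epoch: 0.5}
log = run_em(num_em_iters=1, epochs_per_phase=2, trainers=trainers)
# phases run in order: lm, lm, gnn, gnn
```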
- train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]
Training step of the language model for one epoch.
- Parameters:
train_loader (loader.dataloader.DataLoader) – Text data loader
optimizer (torch.optim.Optimizer) – Model optimizer
epoch (int) – Current training epoch
pseudo_labels (torch.Tensor) – 1-dimensional tensor of predictions from the GNN
is_augmented (bool) – Whether to train with pseudo-labels
verbose (bool) – Whether to print the training progress bar
- Returns:
approx_acc (torch.tensor): training accuracy
loss (torch.float): loss value
- Return type:
approx_acc (torch.tensor)
- train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]
GNN training step for one epoch.
- Parameters:
train_loader (loader.NeighborLoader) – Neighbor loader for the GNN
optimizer (torch.optim.Optimizer) – Model optimizer
epoch (int) – Current training epoch
pseudo_labels (torch.tensor) – 1-dimensional tensor of predictions from the LM
is_augmented (bool) – Whether to use pseudo-labeled nodes
verbose (bool) – Whether to print the training progress
- Returns:
approx_acc (torch.tensor): training accuracy
loss (torch.float): loss value
- Return type:
approx_acc (torch.tensor)
- inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)[source]
GLEM inference step.
- Parameters:
em_phase (str) – 'gnn' or 'lm'
data_loader (Union[NeighborLoader, DataLoader]) – Data loader for the corresponding phase
verbose (bool) – Whether to print the progress bar
- Returns:
- An n * m tensor, where m is the number of classes
and n is the number of nodes
- Return type:
out (torch.Tensor)
- inference_lm(data_loader: DataLoader, verbose: bool = True)[source]
LM inference step.
- Parameters:
data_loader (DataLoader) – Contains tokens, labels and the gold mask
verbose (bool) – Whether to print the progress bar
- Returns:
- Predictions from the LM, converted to pseudo-labels
by preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (tensor)
- inference_gnn(data_loader: NeighborLoader, verbose: bool = True)[source]
GNN inference step.
- Parameters:
data_loader (NeighborLoader) – Contains x and edge_index
verbose (bool) – Whether to print the progress bar
- Returns:
- Predictions from the GNN, converted to pseudo-labels
by preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (tensor)
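The argmax conversion mentioned in both inference methods can be illustrated with a plain-Python stand-in for preds.argmax(dim=-1).unsqueeze(1); the to_pseudo_labels helper below is hypothetical and not part of torch_geometric:

```python
def to_pseudo_labels(logits):
    """Plain-Python stand-in for preds.argmax(dim=-1).unsqueeze(1).

    logits: an n x m list of per-node class scores.
    Returns an n x 1 list holding the highest-scoring class per node,
    keeping the trailing singleton dimension like unsqueeze(1) does.
    """
    return [[max(range(len(row)), key=row.__getitem__)] for row in logits]

# Two nodes, three classes each:
labels = to_pseudo_labels([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
# labels == [[1], [0]]
```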
- loss(logits: torch.Tensor, labels: torch.Tensor, loss_func: torch.nn.functional, is_gold: torch.Tensor, pseudo_labels: Optional[torch.Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)[source]
Core function of variational EM inference; it combines the loss on gold (original training) labels with the loss on pseudo-labels.
Reference: <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa
- Parameters:
logits (torch.tensor) – Predictions from the LM or GNN
labels (torch.tensor) – Combined node labels from ground-truth labels and pseudo-labels (if provided)
loss_func (torch.nn.modules.loss) – Loss function for classification
is_gold (tensor) – A boolean tensor masking the ground-truth labels during training, so that ~is_gold masks the pseudo-labels
pseudo_labels (torch.tensor) – Predictions from the other model
pl_weight (float, default: 0.5) – Pseudo-label weight used in the E-step and M-step optimization: alpha in the E-step, beta in the M-step
is_augmented (bool, default: True) – Whether to train the GNN and LM with EM or with gold data only
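As a minimal sketch of the pl_weight weighting described above (a scalar stand-in; the actual implementation applies the loss function per node and separates gold from pseudo-labeled nodes with the is_gold mask):

```python
def combined_loss(gold_loss, pseudo_loss, pl_weight=0.5, is_augmented=True):
    # Convex combination of the gold-label loss and the pseudo-label loss.
    # pl_weight plays the role of alpha in the E-step and beta in the M-step;
    # without augmentation, only the gold-label loss is used.
    if not is_augmented:
        return gold_loss
    return pl_weight * pseudo_loss + (1.0 - pl_weight) * gold_loss

# Equal weighting averages the two losses:
# combined_loss(1.0, 3.0, pl_weight=0.5) -> 2.0
```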