torch_geometric.nn.models.GLEM

class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: torch_geometric.nn.models.basic_gnn = GraphSAGE, out_channels: int = 47, gnn_loss=CrossEntropyLoss(), lm_loss=CrossEntropyLoss(), alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Union[str, torch.device] = device(type='cpu'))[source]

Bases: Module

This GNN+LM co-training model is based on GLEM from the "Learning on Large-Scale Text-Attributed Graphs via Variational Inference" paper.

Parameters:
  • lm_to_use (str) – A TextEncoder with a classifier from the HuggingFace model hub (default: TinyBERT)

  • gnn_to_use (torch_geometric.nn.models) – (default: GraphSAGE)

  • out_channels (int) – Number of output channels of the LM and GNN; they should match

  • num_gnn_heads (Optional[int]) – Number of attention heads, if needed

  • num_gnn_layers (int) – Number of GNN layers

  • gnn_loss – Loss function for the GNN (default: CrossEntropyLoss)

  • lm_loss – Loss function for the language model (default: CrossEntropyLoss)

  • alpha (float) – Pseudo-label weight for the E-step (LM optimization) (default: 0.5)

  • beta (float) – Pseudo-label weight for the M-step (GNN optimization) (default: 0.5)

  • lm_dtype (torch.dtype) – The dtype the LM is loaded into memory with (default: torch.bfloat16)

  • lm_use_lora (bool) – Whether to fine-tune the LM with LoRA PEFT (default: True)

  • lora_target_modules (Union[str, List[str], None], default: None) – Names of the target modules to apply the LoRA adapter to; for an LLM this could be ['q_proj', 'v_proj'] (default: None)

Note

See examples/llm_plus_gnn/glem.py for an example usage.

forward(*input: Any) None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

None

train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)[source]

GLEM training step: one EM step.

Parameters:
  • em_phase (str) – 'gnn' or 'lm'; selects which phase you are training

  • train_loader (Union[DataLoader, NeighborLoader]) – Use a DataLoader for LM training; it contains the tokenized data, labels, and the is_gold mask. Use a NeighborLoader for GNN training; it contains x and edge_index.

  • optimizer (torch.optim.Optimizer) – Optimizer used for training

  • pseudo_labels (torch.Tensor) – Predicted labels used as pseudo labels

  • epoch (int) – Current epoch

  • is_augmented (bool) – Whether to use pseudo labels

  • verbose (bool) – Whether to print a training progress bar

Returns:

Training accuracy; loss (float): the loss value

Return type:

acc (float)
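The EM alternation that train() drives can be outlined as plain control flow. The ToyGLEM class below is a hypothetical stand-in, not PyG's GLEM; only the alternation of the 'lm' and 'gnn' phases and the hand-off of pseudo labels between them mirror the documented train()/inference() interface:

```python
# Structural sketch of GLEM's variational EM alternation (assumptions:
# ToyGLEM is a stub standing in for a real GLEM instance; loaders and
# optimizers are omitted for brevity).

class ToyGLEM:
    """Minimal stand-in recording which phase was trained, and how."""
    def __init__(self):
        self.history = []

    def train(self, em_phase, pseudo_labels=None, epoch=0, is_augmented=False):
        self.history.append((epoch, em_phase, is_augmented))
        # A real GLEM.train() would return (acc, loss) for the phase.
        return 1.0, 0.0

    def inference(self, em_phase):
        # A real GLEM.inference() returns an [n, m] score tensor; here we
        # return fixed "scores" so the loop stays self-contained.
        return [[0.1, 0.9], [0.8, 0.2]]


def em_loop(model, num_em_iters=2):
    """Alternate E-step (LM) and M-step (GNN), passing pseudo labels across."""
    pseudo = None
    for epoch in range(num_em_iters):
        # E-step: train the LM, augmented with GNN pseudo labels if available.
        model.train('lm', pseudo_labels=pseudo, epoch=epoch,
                    is_augmented=pseudo is not None)
        pseudo = model.inference('lm')
        # M-step: train the GNN, augmented with the LM's pseudo labels.
        model.train('gnn', pseudo_labels=pseudo, epoch=epoch, is_augmented=True)
        pseudo = model.inference('gnn')
    return model.history


history = em_loop(ToyGLEM())
```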

train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

Language model training for each epoch.

Parameters:
  • train_loader (DataLoader) – Loader with the tokenized data, labels, and the is_gold mask

  • optimizer (torch.optim.Optimizer) – model optimizer

  • epoch (int) – current train epoch

  • pseudo_labels (torch.Tensor) – 1-D tensor of predictions from the GNN

  • is_augmented (bool) – whether to use pseudo-labeled nodes

  • verbose (bool) – whether to print the training progress

Returns:

Training accuracy; loss (torch.float): the loss value

Return type:

approx_acc (torch.tensor)

train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

GNN training step for each epoch.

Parameters:
  • train_loader (loader.NeighborLoader) – GNN neighbor loader

  • optimizer (torch.optim.Optimizer) – model optimizer

  • epoch (int) – current train epoch

  • pseudo_labels (torch.tensor) – 1-D tensor of predictions from the LM

  • is_augmented (bool) – whether to use pseudo-labeled nodes

  • verbose (bool) – whether to print the training progress

Returns:

Training accuracy; loss (torch.float): the loss value

Return type:

approx_acc (torch.tensor)

inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)[source]

GLEM inference step.

Parameters:
  • em_phase (str) – 'gnn' or 'lm'

  • data_loader (Union[DataLoader, NeighborLoader]) – DataLoader: for language model inference, containing the tokenized data; NeighborLoader: for GNN inference, containing x and edge_index

  • verbose (bool) – whether to print the inference progress

Returns:

An n * m tensor, where m is the number of classes and n is the number of nodes

Return type:

out (torch.Tensor)
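The [n, m] score matrix returned here is what inference_lm/inference_gnn turn into pseudo labels via preds.argmax(dim=-1).unsqueeze(1). A dependency-free sketch of that conversion, with plain Python lists standing in for tensors:

```python
def scores_to_pseudo_labels(scores):
    """Convert an [n, m] score matrix to [n, 1] pseudo labels via row-wise
    argmax, mirroring preds.argmax(dim=-1).unsqueeze(1) in torch."""
    labels = []
    for row in scores:
        best = max(range(len(row)), key=lambda j: row[j])  # argmax over classes
        labels.append([best])  # unsqueeze(1): keep a trailing dim of size 1
    return labels

# Two nodes, three classes: node 0 prefers class 2, node 1 prefers class 0.
pseudo = scores_to_pseudo_labels([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
# → [[2], [0]]
```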

inference_lm(data_loader: DataLoader, verbose: bool = True)[source]

LM inference step.

Parameters:
  • data_loader (Dataloader) – containing the tokens, labels, and the is_gold mask

  • verbose (bool) – whether to print the progress bar

Returns:

Predictions from the LM, converted into pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)

inference_gnn(data_loader: NeighborLoader, verbose: bool = True)[source]

GNN inference step.

Parameters:
  • data_loader (NeighborLoader) – containing x and edge_index

  • verbose (bool) – whether to print the progress bar

Returns:

Predictions from the GNN, converted into pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)

loss(logits: Tensor, labels: Tensor, loss_func: torch.nn.functional, is_gold: Tensor, pseudo_labels: Optional[Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)[source]

Core function of variational EM inference: this function combines the loss on gold (original training) labels with the loss on pseudo labels.

Reference: <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py>

Parameters:
  • logits (torch.tensor) – Predictions from the LM or GNN

  • labels (torch.tensor) – Combined node labels from ground-truth labels and pseudo labels (if provided)

  • loss_func (torch.nn.modules.loss) – Loss function for classification

  • is_gold (tensor) – A boolean tensor masking the ground-truth labels during training, so ~is_gold masks the pseudo labels

  • pseudo_labels (torch.tensor) – Predictions from the other model

  • pl_weight (float, default: 0.5) – Weight of the pseudo labels used in the E-step and M-step optimization: alpha in the E-step, beta in the M-step

  • is_augmented (bool, default: True) – Whether to train the GNN and LM with EM, or on gold data only
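The weighting this method applies can be sketched without torch: the loss on gold nodes (is_gold mask) and the loss on pseudo-labeled nodes (its complement) are averaged separately, then mixed with pl_weight. The helper below is a hypothetical, dependency-free version of that weighting on per-node loss values, not PyG's implementation:

```python
def glem_loss(per_node_gold_loss, per_node_pl_loss, is_gold, pl_weight=0.5,
              is_augmented=True):
    """Mix the mean loss on gold nodes with the mean loss on pseudo nodes.

    per_node_gold_loss / per_node_pl_loss: per-node loss values (floats).
    is_gold: boolean mask; True marks nodes with real training labels, so
    `not g` plays the role of ~is_gold for the pseudo-labeled nodes.
    """
    gold = [l for l, g in zip(per_node_gold_loss, is_gold) if g]
    mean_gold = sum(gold) / len(gold) if gold else 0.0
    if not is_augmented:
        return mean_gold  # gold-only training: ignore pseudo labels
    pl = [l for l, g in zip(per_node_pl_loss, is_gold) if not g]
    mean_pl = sum(pl) / len(pl) if pl else 0.0
    # pl_weight is alpha in the E-step (LM) and beta in the M-step (GNN).
    return pl_weight * mean_pl + (1.0 - pl_weight) * mean_gold

total = glem_loss([1.0, 2.0, 0.0], [0.0, 0.0, 4.0],
                  [True, True, False], pl_weight=0.5)
# gold mean = 1.5, pseudo mean = 4.0 → 0.5 * 4.0 + 0.5 * 1.5 = 2.75
```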