Transformers 文档

如何破解任何Transformers模型

如何破解任何Transformers模型

🤗 Transformers 库提供了一系列预训练模型和工具,用于自然语言处理、视觉等领域。虽然这些模型涵盖了广泛的应用场景,但您可能会遇到一些默认不支持的使用案例。自定义模型可以解锁新的可能性,例如添加新层、改变架构或优化注意力机制。本指南将向您展示如何修改现有的 Transformers 模型以满足您的特定需求。最棒的是,您无需离开 Transformers 框架即可进行这些更改。您可以直接在 Transformers 中修改模型,并仍然利用诸如 Trainer APIPreTrainedModel 以及使用 PEFT 等工具进行高效微调的功能。

在本指南中,我们将引导您如何自定义现有的Transformers模型以满足您的需求,同时不失去生态系统的优势。

你将学习如何:

  • 通过改变模型的注意力机制来修改其架构。
  • 将低秩适应(LoRA)等技术应用于特定模型组件。

我们鼓励您贡献自己的技巧,并在这里与社区分享

示例:在Segment Anything Model (SAM)中修改注意力机制

Segment Anything Model (SAM) 是一种用于图像分割的最先进模型。在其默认实现中,SAM 在其注意力机制中使用了组合的查询-键-值(qkv)投影。然而,您可能只想微调注意力机制的特定组件,例如查询(q)和值(v)投影,以减少可训练参数和所需的计算资源。

动机

通过将组合的 qkv 投影拆分为单独的 qkv 投影,您可以应用像 LoRA(低秩适应)这样的技术,仅对 qv 投影进行操作。这种方法使您能够:

  • 微调较少的参数,减少计算开销。
  • 通过专注于特定组件,可能实现更好的性能。
  • 在注意力机制中尝试不同的适应策略。

实现

步骤1:创建一个自定义注意力类

接下来,子类化原始的 SamVisionAttention 类,并修改它以拥有独立的 qkv 投影。

import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention

class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        del self.qkv
        # Separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # Split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # Replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # Mark the old qkv key for deletion
                keys_to_delete.append(key)
        
        # Remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size *  self.num_attention_heads,  height * width, -1)
        query = self.q(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size,  height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs

解释:

  • 分离投影: 移除了组合的 qkv 投影,并创建了单独的 qkv 线性层。
  • 权重加载钩子: _split_qkv_load_hook 方法在加载模型时将预训练的 qkv 权重分割为单独的 qkv 权重。这确保了与任何预训练模型的兼容性。
  • 前向传播: 查询、键和值分别计算,注意力机制按常规进行。

步骤2:替换原始的注意力类

将原始的 SamVisionAttention 类替换为您的自定义类,以便模型使用修改后的注意力机制。

from transformers import SamModel
from transformers.models.sam import modeling_sam

# Replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

# Load the pre-trained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")

解释:

  • 类替换: 通过将您的自定义类分配给 modeling_sam.SamVisionAttention,模型中的任何 SamVisionAttention 实例都将使用修改后的版本。因此,当您调用 SamModel 时,它将使用新定义的 SamVisionAttentionSplit
  • 模型加载: 使用 from_pretrained 加载模型,并集成了自定义的注意力机制。

步骤3:将LoRA应用于特定投影

通过分离的 qkv 投影,您现在可以将 LoRA 应用于特定组件,例如 qv 投影。

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],  # Apply LoRA to q and v projections
    lora_dropout=0.1,
    task_type="mask-generation"
)

# Apply LoRA to the model
model = get_peft_model(model, config)

解释:

  • LoRA 配置: LoraConfig 指定了秩 r、缩放因子 lora_alpha、目标模块("q""v")、dropout 和任务类型。
  • 应用LoRA: get_peft_model 函数将LoRA应用于模型中指定的模块。
  • 参数减少: 通过关注 qv,你可以减少可训练参数的数量,从而加快训练速度并降低内存使用。

步骤4:验证可训练参数的数量

验证可训练参数的数量并查看您的修改产生了什么影响非常简单。

model.print_trainable_parameters()

预期输出:

trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447
trainable params: 912,384 || all params: 94,647,856 || trainable%: 0.9640 # with k 

贡献你自己的技巧

修改预训练模型可以为研究和应用开辟新的途径。通过理解和调整像SAM这样的模型的内部机制,您可以根据特定需求进行定制,优化性能,并尝试新的想法。

如果你已经为Transformers模型开发了自己的技巧,并希望分享它们,请考虑为本文档做出贡献。

  • 打开一个拉取请求: 直接在仓库中分享你的代码更改和改进。
  • 编写文档: 提供清晰的解释和修改的示例。
  • 与社区互动:通过提出问题来讨论您的想法,并从其他开发者和研究人员那里获得反馈。
< > Update on GitHub