如何破解任何Transformers模型
🤗 Transformers 库提供了一系列预训练模型和工具,用于自然语言处理、视觉等领域。虽然这些模型涵盖了广泛的应用场景,但您可能会遇到一些默认不支持的使用案例。自定义模型可以解锁新的可能性,例如添加新层、改变架构或优化注意力机制。本指南将向您展示如何修改现有的 Transformers 模型以满足您的特定需求。最棒的是,您无需离开 Transformers 框架即可进行这些更改。您可以直接在 Transformers 中修改模型,并仍然利用诸如 Trainer API、PreTrainedModel 以及使用 PEFT 等工具进行高效微调的功能。
在本指南中,我们将引导您如何自定义现有的Transformers模型以满足您的需求,同时不失去生态系统的优势。
你将学习如何:
- 通过改变模型的注意力机制来修改其架构。
- 将低秩适应(LoRA)等技术应用于特定模型组件。
我们鼓励您贡献自己的技巧,并在这里与社区分享
示例:在Segment Anything Model (SAM)中修改注意力机制
Segment Anything Model (SAM) 是一种用于图像分割的最先进模型。在其默认实现中,SAM 在其注意力机制中使用了组合的查询-键-值(qkv
)投影。然而,您可能只想微调注意力机制的特定组件,例如查询(q
)和值(v
)投影,以减少可训练参数和所需的计算资源。
动机
通过将组合的 qkv
投影拆分为单独的 q
、k
和 v
投影,您可以应用像 LoRA(低秩适应)这样的技术,仅对 q
和 v
投影进行操作。这种方法使您能够:
- 微调较少的参数,减少计算开销。
- 通过专注于特定组件,可能实现更好的性能。
- 在注意力机制中尝试不同的适应策略。
实现
步骤1:创建一个自定义注意力类
接下来,子类化原始的 SamVisionAttention
类,并修改它以拥有独立的 q
、k
和 v
投影。
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
def __init__(self, config, window_size):
super().__init__(config, window_size)
del self.qkv
# Separate q, k, v projections
self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
def split_q_k_v_load_hook(self, state_dict, prefix, *args):
keys_to_delete = []
for key in list(state_dict.keys()):
if "qkv." in key:
# Split q, k, v from the combined projection
q, k, v = state_dict[key].chunk(3, dim=0)
# Replace with individual q, k, v projections
state_dict[key.replace("qkv.", "q.")] = q
state_dict[key.replace("qkv.", "k.")] = k
state_dict[key.replace("qkv.", "v.")] = v
# Mark the old qkv key for deletion
keys_to_delete.append(key)
# Remove old qkv keys
for key in keys_to_delete:
del state_dict[key]
def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
batch_size, height, width, _ = hidden_states.shape
qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
query = self.q(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
key = self.k(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
value = self.v(hidden_states).reshape((batch_size, height * width,self.num_attention_heads, -1)).permute(0,2,1,3).reshape(qkv_shapes)
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
attn_output = self.proj(attn_output)
if output_attentions:
outputs = (attn_output, attn_weights)
else:
outputs = (attn_output, None)
return outputs
解释:
- 分离投影: 移除了组合的
qkv
投影,并创建了单独的q
、k
和v
线性层。 - 权重加载钩子:
_split_qkv_load_hook
方法在加载模型时将预训练的qkv
权重分割为单独的q
、k
和v
权重。这确保了与任何预训练模型的兼容性。 - 前向传播: 查询、键和值分别计算,注意力机制按常规进行。
步骤2:替换原始的注意力类
将原始的 SamVisionAttention
类替换为您的自定义类,以便模型使用修改后的注意力机制。
from transformers import SamModel
from transformers.models.sam import modeling_sam
# Replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
# Load the pre-trained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
解释:
- 类替换: 通过将您的自定义类分配给
modeling_sam.SamVisionAttention
,模型中的任何SamVisionAttention
实例都将使用修改后的版本。因此,当您调用SamModel
时,它将使用新定义的SamVisionAttentionSplit
。 - 模型加载: 使用
from_pretrained
加载模型,并集成了自定义的注意力机制。
步骤3:将LoRA应用于特定投影
通过分离的 q
、k
和 v
投影,您现在可以将 LoRA 应用于特定组件,例如 q
和 v
投影。
from peft import LoraConfig, get_peft_model
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q", "v"], # Apply LoRA to q and v projections
lora_dropout=0.1,
task_type="mask-generation"
)
# Apply LoRA to the model
model = get_peft_model(model, config)
解释:
- LoRA 配置:
LoraConfig
指定了秩r
、缩放因子lora_alpha
、目标模块("q"
和"v"
)、dropout 和任务类型。 - 应用LoRA:
get_peft_model
函数将LoRA应用于模型中指定的模块。 - 参数减少: 通过关注
q
和v
,你可以减少可训练参数的数量,从而加快训练速度并降低内存使用。
步骤4:验证可训练参数的数量
验证可训练参数的数量并查看您的修改产生了什么影响非常简单。
model.print_trainable_parameters()
预期输出:
trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447
trainable params: 912,384 || all params: 94,647,856 || trainable%: 0.9640 # with k
贡献你自己的技巧
修改预训练模型可以为研究和应用开辟新的途径。通过理解和调整像SAM这样的模型的内部机制,您可以根据特定需求进行定制,优化性能,并尝试新的想法。
如果你已经为Transformers模型开发了自己的技巧,并希望分享它们,请考虑为本文档做出贡献。
- 打开一个拉取请求: 直接在仓库中分享你的代码更改和改进。
- 编写文档: 提供清晰的解释和修改的示例。
- 与社区互动:通过提出问题来讨论您的想法,并从其他开发者和研究人员那里获得反馈。