GroupedQueryAttention
classkeras.layers.GroupQueryAttention(
head_dim,
num_query_heads,
num_key_value_heads,
dropout=0.0,
use_bias=True,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs
)
分组查询注意力层.
这是由Ainslie et al., 2023引入的分组查询注意力的实现.这里num_key_value_heads
表示组的数量,将num_key_value_heads
设置为1等效于多查询注意力,而当num_key_value_heads
等于num_query_heads
时,它等效于多头注意力.
该层首先投影query
、key
和value
张量.然后,key
和value
被重复以匹配query
的头数.
然后,query
被缩放并与key
张量进行点积.这些被softmax处理以获得注意力概率.然后,value
张量由这些概率插值并连接回单个张量.
参数: head_dim: 每个注意力头的大小. num_query_heads: 查询注意力头的数量. num_key_value_heads: 键和值注意力头的数量. dropout: Dropout概率. use_bias: 布尔值,是否在密集层中使用偏置向量/矩阵. kernel_initializer: 密集层核的初始化器. bias_initializer: 密集层偏置的初始化器. kernel_regularizer: 密集层核的正则化器. bias_regularizer: 密集层偏置的正则化器. activity_regularizer: 密集层活动的正则化器. kernel_constraint: 密集层核的约束. bias_constraint: 密集层核的约束.
调用参数:
query: 查询张量,形状为(batch_dim, target_seq_len, feature_dim)
,其中batch_dim
是批量大小,target_seq_len
是目标序列的长度,feature_dim
是特征的维度.
value: 值张量,形状为(batch_dim, source_seq_len, feature_dim)
,其中batch_dim
是批量大小,source_seq_len
是源序列的长度,feature_dim
是特征的维度.
key: 可选的键张量,形状为(batch_dim, source_seq_len, feature_dim)
.如果不给定,将使用value
作为key
和value
,这是最常见的情况.
attention_mask: 形状为(batch_dim, target_seq_len, source_seq_len)
的布尔掩码,防止对某些位置的注意力.布尔掩码指定哪些查询元素可以关注哪些键元素,其中1表示关注,0表示不关注.缺失的批量维度和头维度可以进行广播.
return_attention_scores: 一个布尔值,指示输出应该是(attention_output, attention_scores)
(如果为True
),还是attention_output
(如果为False
).默认为False
.
training: Python布尔值,指示层是否应在训练模式(添加dropout)或推理模式(无dropout)下运行.将使用父层/模型的训练模式或False
(推理),如果没有父层.
use_causal_mask: 一个布尔值,指示是否应用因果掩码以防止令牌关注未来令牌(例如,在解码器Transformer中使用).
返回:
attention_output: 计算结果,形状为(batch_dim, target_seq_len, feature_dim)
,其中target_seq_len
是目标序列长度,feature_dim
是查询输入的最后一个维度.
attention_scores: (可选)注意力系数,形状为(batch_dim, num_query_heads, target_seq_len, source_seq_len)
.