Keras 3 API 文档 / 层 API / 注意力层 / 组查询注意力

组查询注意力

[source]

GroupedQueryAttention class

keras.layers.GroupQueryAttention(
    head_dim,
    num_query_heads,
    num_key_value_heads,
    dropout=0.0,
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    **kwargs
)

分组查询注意力层.

这是由Ainslie et al., 2023引入的分组查询注意力的实现.这里num_key_value_heads表示组的数量,将num_key_value_heads设置为1等效于多查询注意力,而当num_key_value_heads等于num_query_heads时,它等效于多头注意力.

该层首先投影querykeyvalue张量.然后,keyvalue被重复以匹配query的头数.

然后,query被缩放并与key张量进行点积.这些被softmax处理以获得注意力概率.然后,value张量由这些概率插值并连接回单个张量.

参数: head_dim: 每个注意力头的大小. num_query_heads: 查询注意力头的数量. num_key_value_heads: 键和值注意力头的数量. dropout: Dropout概率. use_bias: 布尔值,是否在密集层中使用偏置向量/矩阵. kernel_initializer: 密集层核的初始化器. bias_initializer: 密集层偏置的初始化器. kernel_regularizer: 密集层核的正则化器. bias_regularizer: 密集层偏置的正则化器. activity_regularizer: 密集层活动的正则化器. kernel_constraint: 密集层核的约束. bias_constraint: 密集层核的约束.

调用参数: query: 查询张量,形状为(batch_dim, target_seq_len, feature_dim),其中batch_dim是批量大小,target_seq_len是目标序列的长度,feature_dim是特征的维度. value: 值张量,形状为(batch_dim, source_seq_len, feature_dim),其中batch_dim是批量大小,source_seq_len是源序列的长度,feature_dim是特征的维度. key: 可选的键张量,形状为(batch_dim, source_seq_len, feature_dim).如果不给定,将使用value作为keyvalue,这是最常见的情况. attention_mask: 形状为(batch_dim, target_seq_len, source_seq_len)的布尔掩码,防止对某些位置的注意力.布尔掩码指定哪些查询元素可以关注哪些键元素,其中1表示关注,0表示不关注.缺失的批量维度和头维度可以进行广播. return_attention_scores: 一个布尔值,指示输出应该是(attention_output, attention_scores)(如果为True),还是attention_output(如果为False).默认为False. training: Python布尔值,指示层是否应在训练模式(添加dropout)或推理模式(无dropout)下运行.将使用父层/模型的训练模式或False(推理),如果没有父层. use_causal_mask: 一个布尔值,指示是否应用因果掩码以防止令牌关注未来令牌(例如,在解码器Transformer中使用).

返回: attention_output: 计算结果,形状为(batch_dim, target_seq_len, feature_dim),其中target_seq_len是目标序列长度,feature_dim是查询输入的最后一个维度. attention_scores: (可选)注意力系数,形状为(batch_dim, num_query_heads, target_seq_len, source_seq_len).