MultiHeadAttention
classkeras.layers.MultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
多头注意力层.
这是在论文《注意力就是你所需要的一切》Vaswani et al., 2017中描述的多头注意力的实现.
如果query
、key
、value
相同,那么这就是自注意力.query
中的每个时间步都会关注key
中相应的序列,并返回一个固定宽度的向量.
该层首先投影query
、key
和value
.这些实际上是一个长度为num_attention_heads
的张量列表,对应的形状为(batch_size, <query dimensions>, key_dim)
、(batch_size, <key/value dimensions>, key_dim)
、(batch_size, <key/value dimensions>, value_dim)
.
然后,查询和键张量进行点积并缩放.这些经过softmax处理以获得注意力概率.然后,值张量根据这些概率进行插值,然后连接回单个张量.
最后,结果张量的最后一个维度为value_dim
,可以进行线性投影并返回.
参数:
num_heads: 注意力头的数量.
key_dim: 查询和键的每个注意力头的大小.
value_dim: 值的每个注意力头的大小.
dropout: dropout概率.
use_bias: 布尔值,表示密集层是否使用偏置向量/矩阵.
output_shape: 输出张量的预期形状,除了批次和序列维度.如果未指定,则投影回查询特征维度(查询输入的最后一个维度).
attention_axes: 应用注意力的轴.None
表示对所有轴进行注意力,但不包括批次、头和特征.
kernel_initializer: 密集层核的初始化器.
bias_initializer: 密集层偏置的初始化器.
kernel_regularizer: 密集层核的正则化器.
bias_regularizer: 密集层偏置的正则化器.
activity_regularizer: 密集层活动的正则化器.
kernel_constraint: 密集层核的约束.
bias_constraint: 密集层核的约束.
seed: 用于dropout层的可选整数种子.
调用参数:
query: 查询张量的形状为(B, T, dim)
,其中B
是批次大小,T
是目标序列长度,dim是特征维度.
value: 值张量的形状为(B, S, dim)
,其中B
是批次大小,S
是源序列长度,dim是特征维度.
key: 可选的键张量的形状为(B, S, dim)
.如果未给出,将使用value
作为key
和value
,这是最常见的情况.
attention_mask: 形状为(B, T, S)
的布尔掩码,防止对某些位置进行注意力.布尔掩码指定哪些查询元素可以关注哪些键元素,1表示注意力,0表示无注意力.可以对缺失的批次维度和头维度进行广播.
return_attention_scores: 一个布尔值,指示输出应该是(attention_output, attention_scores)
(如果为True
),还是attention_output
(如果为False
).默认为False
.
training: Python布尔值,指示层是否应在训练模式(添加dropout)或推理模式(无dropout)下运行.将使用父层/模型的训练模式,或者如果没有父层,则使用False
(推理).
use_causal_mask: 一个布尔值,指示是否应用因果掩码以防止令牌关注未来令牌(例如,在解码器Transformer中使用).
返回:
attention_output: 计算结果,形状为(B, T, E)
,其中T
是目标序列形状,E
是查询输入的最后一个维度(如果output_shape
为None
).否则,多头的输出将投影到由output_shape
指定的形状.
attention_scores: (可选)注意力轴上的多头注意力系数.