Keras 3 API 文档 / 层 API / 注意力层 / 多头注意力层

多头注意力层

[source]

MultiHeadAttention class

keras.layers.MultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

多头注意力层.

这是在论文《注意力就是你所需要的一切》Vaswani et al., 2017中描述的多头注意力的实现. 如果querykeyvalue相同,那么这就是自注意力.query中的每个时间步都会关注key中相应的序列,并返回一个固定宽度的向量.

该层首先投影querykeyvalue.这些实际上是一个长度为num_attention_heads的张量列表,对应的形状为(batch_size, <query dimensions>, key_dim)(batch_size, <key/value dimensions>, key_dim)(batch_size, <key/value dimensions>, value_dim).

然后,查询和键张量进行点积并缩放.这些经过softmax处理以获得注意力概率.然后,值张量根据这些概率进行插值,然后连接回单个张量.

最后,结果张量的最后一个维度为value_dim,可以进行线性投影并返回.

参数: num_heads: 注意力头的数量. key_dim: 查询和键的每个注意力头的大小. value_dim: 值的每个注意力头的大小. dropout: dropout概率. use_bias: 布尔值,表示密集层是否使用偏置向量/矩阵. output_shape: 输出张量的预期形状,除了批次和序列维度.如果未指定,则投影回查询特征维度(查询输入的最后一个维度). attention_axes: 应用注意力的轴.None表示对所有轴进行注意力,但不包括批次、头和特征. kernel_initializer: 密集层核的初始化器. bias_initializer: 密集层偏置的初始化器. kernel_regularizer: 密集层核的正则化器. bias_regularizer: 密集层偏置的正则化器. activity_regularizer: 密集层活动的正则化器. kernel_constraint: 密集层核的约束. bias_constraint: 密集层核的约束. seed: 用于dropout层的可选整数种子.

调用参数: query: 查询张量的形状为(B, T, dim),其中B是批次大小,T是目标序列长度,dim是特征维度. value: 值张量的形状为(B, S, dim),其中B是批次大小,S是源序列长度,dim是特征维度. key: 可选的键张量的形状为(B, S, dim).如果未给出,将使用value作为keyvalue,这是最常见的情况. attention_mask: 形状为(B, T, S)的布尔掩码,防止对某些位置进行注意力.布尔掩码指定哪些查询元素可以关注哪些键元素,1表示注意力,0表示无注意力.可以对缺失的批次维度和头维度进行广播. return_attention_scores: 一个布尔值,指示输出应该是(attention_output, attention_scores)(如果为True),还是attention_output(如果为False).默认为False. training: Python布尔值,指示层是否应在训练模式(添加dropout)或推理模式(无dropout)下运行.将使用父层/模型的训练模式,或者如果没有父层,则使用False(推理). use_causal_mask: 一个布尔值,指示是否应用因果掩码以防止令牌关注未来令牌(例如,在解码器Transformer中使用).

返回: attention_output: 计算结果,形状为(B, T, E),其中T是目标序列形状,E是查询输入的最后一个维度(如果output_shapeNone).否则,多头的输出将投影到由output_shape指定的形状. attention_scores: (可选)注意力轴上的多头注意力系数.