jax.nn.点积注意力

jax.nn.点积注意力#

jax.nn.dot_product_attention(query, key, value, bias=None, mask=None, *, scale=None, is_causal=False, query_seq_lengths=None, key_value_seq_lengths=None, implementation=None)[源代码][源代码]#

缩放点积注意力函数。

计算 Query、Key 和 Value 张量上的注意力函数:

\[\mathrm{Attention}(Q, K, V)=\mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V\]

如果我们定义 logits\(QK^T\) 的输出,并且 probs\(softmax\) 的输出。

在整个函数中,我们使用以下大写字母来表示数组的形状:

B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head
K = number of key/value heads
G = number of groups, which equals to N // K
参数:
  • query (ArrayLike) – 查询数组;形状 (BTNH|TNH)

  • key (ArrayLike) – 键数组:形状 (BSKH|SKH)。当 K 等于 N 时,执行多头注意力(MHA https://arxiv.org/abs/1706.03762)。否则,如果 NK 的倍数,则执行分组查询注意力(GQA https://arxiv.org/abs/2305.13245),如果 `K == 1`(GQA 的特殊情况),则执行多查询注意力(MQA https://arxiv.org/abs/1911.02150)。

  • value (ArrayLike) – 值数组,应与 key 数组具有相同的形状。

  • bias (ArrayLike | None) – 可选,要添加到对数中的偏置数组;形状必须是4D,并且可以广播到 (BNTS|NTS)

  • mask (ArrayLike | None) – 可选的掩码数组,用于过滤掉logits。它是一个布尔掩码,其中`True`表示元素应参与注意力计算。对于加性掩码,用户应将其传递给`bias`。形状必须为4D,并且可广播到 (BNTS|NTS)

  • scale (float | None) – logits 的缩放比例。如果为 None,缩放比例将设置为 1 除以查询的头维度(即 H)的平方根。

  • is_causal (bool) – 如果为真,将应用因果注意力。注意,一些实现如 xla 会生成一个掩码张量并将其应用于对数以屏蔽注意力矩阵的非因果部分,但其他实现如 cudnn 将避免计算非因果区域,从而提供加速。

  • query_seq_lengths (ArrayLike | None) – 查询序列长度的 int32 数组;形状 (B)

  • key_value_seq_lengths (ArrayLike | None) – int32 类型的序列长度数组,用于键和值;形状 (B)

  • implementation (Literal['xla', 'cudnn'] | None) – 用于控制使用哪个实现后端的字符串。支持的字符串有 xlacudnn`(cuDNN 快速注意力)。默认值为 `None,将自动选择最佳可用后端。注意,cudnn 仅支持部分形状/数据类型,如果不支持,将抛出异常。

返回:

query 形状相同的注意力输出数组。

返回类型:

Array