vLLM 分页注意力#
目前,vLLM 使用其自己的多头查询注意力内核实现(
csrc/attention/attention_kernels.cu
)。该内核设计为与 vLLM 的分页 KV 缓存兼容,其中键和值缓存存储在单独的块中(注意,此块概念与 GPU 线程块不同。因此,在后续文档中,我将 vLLM 分页注意力块称为“块”,而将 GPU 线程块称为“线程块”)。为了实现高性能,此内核依赖于一种特别设计的内存布局和访问方法,特别是在线程从全局内存读取数据到共享内存时。本文档的目的是逐步提供内核实现的高层次解释,帮助那些希望了解 vLLM 多头查询注意力内核的人。阅读完本文档后,用户可能会对实际实现有更好的理解,并感到更容易跟随。
请注意,本文档可能不会涵盖所有细节,例如如何计算相应数据的正确索引或点乘法的实现。然而,在阅读本文档并熟悉高级逻辑流程后,您应该更容易阅读实际代码并理解细节。
输入#
内核函数接受当前线程执行其分配工作的参数列表。最重要的三个参数是输入指针
q
、k_cache
和v_cache
,它们指向全局内存中需要读取和处理的查询、键和值数据。输出指针out
指向全局内存,结果应写入该位置。这四个指针实际上指的是多维数组,但每个线程仅访问分配给它的数据部分。为了简化,我省略了此处所有其他运行时参数。template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, int PARTITION_SIZE = 0> __device__ void paged_attention_kernel( ... // Other side args. const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] ... // Other side args. )
在函数签名上方也有一组模板参数,这些参数在编译时确定。
scalar_t
表示查询、键和值数据元素的数据类型,例如 FP16。HEAD_SIZE
表示每个头中的元素数量。BLOCK_SIZE
指的是每个块中的令牌数量。NUM_THREADS
表示每个线程块中的线程数量。PARTITION_SIZE
表示张量并行 GPU 的数量(为简单起见,我们假设这是 0 并且张量并行被禁用)。通过这些参数,我们需要执行一系列准备工作。这包括计算当前的头部索引、块索引和其他必要的变量。然而,目前我们可以忽略这些准备工作,直接进行实际的计算。一旦我们掌握了整个流程,理解它们将会更容易。
概念#
在我们深入计算流程之前,我想先描述一些后续章节所需的概念。不过,如果你遇到任何令人困惑的术语,你可以跳过这一部分,稍后再回来。
序列:序列代表一个客户端请求。例如,由
q
指向的数据具有[num_seqs, num_heads, head_size]
的形状。这表示总共有num_seqs
个查询序列数据由q
指向。由于此内核是单查询注意力内核,每个序列只有一个查询令牌。因此,num_seqs
等于在批次中处理的总令牌数。上下文:上下文由序列生成的标记组成。例如,
["What", "is", "your"]
是上下文标记,输入查询标记是"name"
。模型可能会生成标记"?"
。Vec: vec 是一个元素列表,这些元素被一起获取和计算。对于查询和键数据,vec 大小 (
VEC_SIZE
) 被确定,以便每个线程组可以一次获取和计算 16 字节的数据。对于值数据,vec 大小 (V_VEC_SIZE
) 被确定,以便每个线程可以一次获取和计算 16 字节的数据。例如,如果scalar_t
是 FP16(2 字节)且THREAD_GROUP_SIZE
是 2,那么VEC_SIZE
将是 4,而V_VEC_SIZE
将是 8。线程组: 线程组是一小组线程(
THREAD_GROUP_SIZE
),一次获取并计算一个查询令牌和一个键令牌。每个线程只处理令牌数据的一部分。一个线程组处理的总元素数称为x
。例如,如果线程组包含 2 个线程,头大小为 8,那么线程 0 处理索引为 0、2、4、6 的查询和键元素,而线程 1 处理索引为 1、3、5、7 的元素。块:vLLM中的键和值缓存数据被分割成块。每个块存储一个头中固定数量(
BLOCK_SIZE
)的标记数据。每个块可能只包含整个上下文标记的一部分。例如,如果块大小为16,头大小为128,那么对于一个头,一个块可以存储16 * 128 = 2048个元素。Warp: Warp 是一组 32 个线程(
WARP_SIZE
),它们在流多处理器(SM)上同时执行。在这个内核中,每个 warp 一次处理一个查询令牌与一个完整块的关键令牌之间的计算(它可能在多次迭代中处理多个块)。例如,如果有 4 个 warp 和 6 个块用于一个上下文,分配将类似于 warp 0 处理第 0 和第 4 块,warp 1 处理第 1 和第 5 块,warp 2 处理第 2 块,warp 3 处理第 3 块。线程块:线程块是一组线程(
NUM_THREADS
),可以访问相同的共享内存。每个线程块包含多个线程束(NUM_WARPS
),在此内核中,每个线程块处理一个查询令牌与整个上下文的关键令牌之间的计算。网格:网格是线程块的集合,并定义了集合的形状。在这个内核中,形状是
(num_heads, num_seqs, max_num_partitions)
。因此,每个线程块仅处理一个头、一个序列和一个分区的计算。
查询#
本节将介绍查询数据如何在内存中存储并通过每个线程获取。如上所述,每个线程组获取一个查询令牌数据,而每个线程本身仅处理一个查询令牌数据的一部分。在每个warp中,每个线程组将获取相同的查询令牌数据,但会将其与不同的键令牌数据相乘。
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
每个线程定义了自己的
q_ptr
,它指向全局内存中分配的查询令牌数据。例如,如果VEC_SIZE
是 4 且HEAD_SIZE
是 128,那么q_ptr
指向的数据包含总共 128 个元素,这些元素被分成 128 / 4 = 32 个向量。__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
接下来,我们需要将
q_ptr
指向的全局内存数据读入共享内存中,作为q_vecs
。需要注意的是,每个 vecs 被分配到不同的行。例如,如果THREAD_GROUP_SIZE
是 2,线程 0 将处理第 0 行 vecs,而线程 1 处理第 1 行 vecs。通过这种方式读取查询数据,相邻的线程如线程 0 和线程 1 可以读取相邻的内存,从而实现内存合并以提高性能。
键#
类似于“查询”部分,本节介绍键的内存布局和分配。虽然每个线程组在一次内核运行中只处理一个查询令牌,但它可能在多次迭代中处理多个键令牌。同时,每个线程束将在多次迭代中处理多个键令牌块,确保在内核运行后整个线程组处理所有上下文令牌。在这种情况下,“处理”指的是在查询数据和键数据之间执行点乘运算。
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + physical_block_offset * x;
与
q_ptr
不同,每个线程中的k_ptr
将在不同迭代中指向不同的键标记。如上所示,k_ptr
根据k_cache
指向分配块、分配头和分配标记中的键标记数据。上图展示了关键数据的内存布局。假设
BLOCK_SIZE
为 16,HEAD_SIZE
为 128,x
为 8,THREAD_GROUP_SIZE
为 2,总共有 4 个 warp。每个矩形代表一个头部的所有关键令牌元素,这些元素将由一个线程组处理。左半部分显示了 warp 0 的 16 个关键令牌数据块,而右半部分代表其他 warp 或迭代的剩余关键令牌数据。在每个矩形内部,总共有 32 个 vecs(一个令牌的 128 个元素)将由 2 个线程(一个线程组)分别处理。K_vec k_vecs[NUM_VECS_PER_THREAD]
接下来,我们需要从
k_ptr
读取关键令牌数据,并将它们存储在寄存器内存中作为k_vecs
。我们为k_vecs
使用寄存器内存,因为它们只会被一个线程访问一次,而q_vecs
将被多个线程多次访问。每个k_vecs
将包含多个向量用于后续计算。每个向量将在每次内部迭代中设置。向量的分配允许warp中的相邻线程一起读取相邻的内存,这再次促进了内存合并。例如,线程0将读取向量0,而线程1将读取向量1。在下一个内部循环中,线程0将读取向量2,而线程1将读取向量3,依此类推。您可能对整体流程还有些困惑。别担心,请继续阅读下一节“QK”。它将以更清晰、更高层次的方式说明查询和键的计算流程。
QK#
如下面伪代码所示,在整个for循环块之前,我们获取一个token的查询数据并将其存储在
q_vecs
中。然后,在外部for循环中,我们遍历指向不同token的k_ptrs
,并在内部for循环中准备k_vecs
。最后,我们在q_vecs
和每个k_vecs
之间执行点乘运算。q_vecs = ... for ... { k_ptr = ... for ... { k_vecs[i] = ... } ... float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs); }
如前所述,对于每个线程,它一次只获取部分查询和键标记数据。然而,在
Qk_dot<>::dot
中会发生跨线程组缩减。因此,这里返回的qk
不仅仅是部分查询和键标记点乘的结果,实际上是整个查询和键标记数据之间的完整结果。例如,如果
HEAD_SIZE
的值为 128 且THREAD_GROUP_SIZE
为 2,每个线程的k_vecs
将包含总共 64 个元素。然而,返回的qk
实际上是 128 个查询元素和 128 个键元素之间的点乘结果。如果你想了解更多关于点乘和归约的细节,可以参考Qk_dot<>::dot
的实现。不过,为了简单起见,本文档中将不涉及这些内容。
Softmax#
接下来,我们需要计算所有
qk
的归一化 softmax,如上所示,其中每个 \(x\) 代表一个qk
。为此,我们必须获取qk_max
(\(m(x)\)) 的归约值以及所有qk
的exp_sum
(\(\ell(x)\))。归约应在整个线程块中执行,涵盖查询令牌与所有上下文键令牌之间的结果。\begin{gather*} m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} \end{gather*}
qk_max
和 logits
#
在我们得到
qk
结果后,我们可以用qk
设置临时的logits
结果(最终,logits
应该存储归一化的 softmax 结果)。同时,我们还可以比较并收集当前线程组计算的所有qk
的qk_max
。if (thread_group_offset == 0) { const bool mask = token_idx >= context_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; qk_max = mask ? qk_max : fmaxf(qk_max, qk); }
请注意,这里的
logits
是在共享内存中,因此每个线程组将为其分配的上下文令牌设置字段。总体而言,logits 的大小应为上下文令牌的数量。for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; }
然后我们需要在每个warp中获取减少的
qk_max
。主要思想是让warp中的线程相互通信,并获取最终的最大qk
。for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } qk_max = VLLM_SHFL_SYNC(qk_max, 0);
最后,我们可以通过比较线程块中所有warp的``qk_max``来获得整个线程块的简化``qk_max``。然后,我们需要将最终结果广播到每个线程。
exp_sum
#
类似于
qk_max
,我们也需要从整个线程块中获取缩减后的和值。for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; } ... exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
首先,将每个线程组中的所有 exp 值相加,同时,将
logits
的每个条目从qk
转换为exp(qk - qk_max)
。请注意,这里的qk_max
已经是整个线程块中的最大qk
。然后,我们可以像qk_max
一样对整个线程块的exp_sum
进行归约。const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; }
最后,通过减少
qk_max
和exp_sum
,我们可以获得最终的归一化softmax结果,即logits
。这个logits
变量将在后续步骤中用于与值数据进行点乘。现在,它应该存储所有分配的上下文令牌的qk
的归一化softmax结果。
值#
现在我们需要检索值数据,并将其与
logits
进行点乘。与查询和键不同,值数据没有线程组的概念。如图所示,与键的标记内存布局不同,同一列的元素对应于相同的值标记。对于一个值数据块,有HEAD_SIZE
行和BLOCK_SIZE
列,这些列被分割成多个v_vecs
。每个线程总是从相同的
V_VEC_SIZE
个标记中一次获取V_VEC_SIZE
个元素。因此,单个线程通过多次内部迭代从不同行和相同列中检索多个v_vec
。对于每个v_vec
,它需要与相应的logits_vec
进行点乘,这也是来自logits
的V_VEC_SIZE
个元素。总体而言,通过多次内部迭代,每个线程块将处理一个值标记块。通过多次外部迭代,整个上下文值标记被处理。float accs[NUM_ROWS_PER_THREAD]; for ... { // Iteration over different blocks. logits_vec = ... for ... { // Iteration over different rows. v_vec = ... ... accs[i] += dot(logits_vec, v_vec); } }
如上伪代码所示,在外层循环中,类似于
k_ptr
,logits_vec
遍历不同的块并从logits
中读取V_VEC_SIZE
个元素。在内层循环中,每个线程读取与v_vec
相同的标记的V_VEC_SIZE
个元素并执行点乘运算。需要注意的是,在每次内层迭代中,线程为相同的标记获取不同的头部位置元素。点乘结果随后累加到accs
中。因此,accs
的每个条目都映射到当前线程分配的头部位置。例如,如果
BLOCK_SIZE
是 16 且V_VEC_SIZE
是 8,每个线程一次获取 8 个值元素用于 8 个标记。每个元素来自同一头部位置的不同标记。如果HEAD_SIZE
是 128 且WARP_SIZE
是 32,对于每个内部循环,一个 warp 需要获取WARP_SIZE * V_VEC_SIZE = 256
个元素。这意味着一个 warp 处理整个值标记块需要总共 128 * 16 / 256 = 8 次内部迭代。每个线程中的每个accs
包含 8 个元素,这些元素在 8 个不同的头部位置累积。对于线程 0,accs
变量将有 8 个元素,它们是从所有分配的 8 个标记中累积的第 0、32、…、224 个值头部元素。
LV#
现在,我们需要在每个warp内对
accs
进行归约。这个过程允许每个线程为分配到的所有token的头位置累积accs
。for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { float acc = accs[i]; for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; }
接下来,我们对所有线程中的
accs
进行归约,使得每个线程都能拥有所有上下文标记分配的头位置的accs
累加结果。请注意,每个线程中的accs
仅存储了所有上下文标记中整个头的一部分元素的累加结果。然而,总体上,所有输出的结果已经计算完毕,只是存储在不同的线程寄存器内存中。float* out_smem = reinterpret_cast<float*>(shared_mem); for (int i = NUM_WARPS; i > 1; i /= 2) { // Upper warps write to shared memory. ... float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ... dst[row_idx] = accs[i]; } // Lower warps update the output. const float* src = &out_smem[warp_idx * HEAD_SIZE]; for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { ... accs[i] += src[row_idx]; } // Write out the accs. }
输出#
现在我们可以将本地寄存器内存中的所有计算结果写入最终的输出全局内存。
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
首先,我们需要定义
out_ptr
变量,该变量指向分配序列的起始地址和分配头。for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { from_float(*(out_ptr + row_idx), accs[i]); } }
最后,我们需要遍历不同的分配头位置,并根据
out_ptr
写出相应的累积结果。