FP8 E4M3 KV 缓存

FP8 E4M3 KV 缓存#

将 KV 缓存量化为 FP8 可以减少其内存占用。这增加了可以在缓存中存储的令牌数量,从而提高了吞吐量。OCP(开放计算项目 www.opencompute.org)指定了两种常见的 8 位浮点数据格式:E5M2(5 个指数位和 2 个尾数位)和 E4M3FN(4 个指数位和 3 个尾数位),通常简称为 E4M3。与 E5M2 相比,E4M3 格式的一个优点是浮点数的表示精度更高。然而,FP8 E4M3 的小动态范围(±240.0 可以表示)通常需要与每个量化张量一起使用更高精度的(通常是 FP32)缩放因子。目前,仅支持每个张量(标量)缩放因子。开发正在进行中,以支持更细粒度的缩放因子(例如,每个通道)。

这些缩放因子可以通过在加载时向LLM引擎传递一个可选的量化参数JSON来指定。如果没有指定这个JSON,缩放因子默认为1.0。这些缩放因子通常在通过量化工具(例如AMD量化器或NVIDIA AMMO)运行未量化的模型时获得。

要安装 AMMO(算法模型优化):

$ pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo

研究表明,FP8 E4M3 量化通常只会轻微降低推理精度。最新的硅产品,例如 AMD MI300、NVIDIA Hopper 或更新的产品,支持原生硬件转换为 fp32、fp16、bf16 等。因此,LLM 推理在最小精度损失的情况下大大加速。

以下是如何启用此功能的示例:

# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to
# https://github.com/vllm-project/vllm/blob/main/examples/fp8/README.md to generate kv_cache_scales.json of your own.

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8",
          quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)

# output w/ scaling factors:  England, the United Kingdom, and one of the world's leading financial,
# output w/o scaling factors:  England, located in the southeastern part of the country. It is known