vLLM 中的推测性解码#
警告
请注意,vLLM 中的推测性解码尚未优化,并且通常不会为所有提示数据集或采样参数带来令牌间延迟的减少。优化工作正在进行中,可以在 此问题 中跟进。
本文档展示了如何使用 推测性解码 与 vLLM。推测性解码是一种提高内存受限 LLM 推理中令牌间延迟的技术。
使用草稿模型进行推测#
以下代码在离线模式下配置 vLLM,使用带有草稿模型的推测解码,每次推测 5 个标记。
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="facebook/opt-125m",
num_speculative_tokens=5,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
要在在线模式下执行相同的操作,请启动服务器:
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \
--seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \
--num_speculative_tokens 5 --gpu_memory_utilization 0.8
Then use a client:
from openai import OpenAI
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
# Completion API
stream = False
completion = client.completions.create(
model=model,
prompt="The future of AI is",
echo=False,
n=1,
stream=stream,
)
print("Completion results:")
if stream:
for c in completion:
print(c)
else:
print(completion)
通过匹配提示中的n-gram进行推测#
以下代码配置 vLLM 使用推测解码,其中提案是通过匹配提示中的 n-gram 生成的。更多信息请阅读 这个帖子。
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_model="[ngram]",
num_speculative_tokens=5,
ngram_prompt_lookup_max=4,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
使用 MLP 投机者进行投机#
以下代码配置 vLLM 使用推测解码,其中提案由草稿模型生成,这些模型根据上下文向量和采样令牌生成草稿预测。更多信息请参见 这篇博客 或 这篇技术报告。
from vllm import LLM, SamplingParams
prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(
model="meta-llama/Meta-Llama-3.1-70B-Instruct",
tensor_parallel_size=4,
speculative_model="ibm-fms/llama3-70b-accelerator",
speculative_draft_tensor_parallel_size=1,
use_v2_block_manager=True,
)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
请注意,这些推测性模型目前需要在没有张量并行的情况下运行,尽管可以使用张量并行运行主模型(见上例)。由于推测性模型相对较小,我们仍然看到显著的加速。然而,这一限制将在未来的版本中修复。
HF hub 上提供了多种此类型的推测模型:
推测解码的无损保证#
在vLLM中,推测性解码旨在提高推理效率的同时保持准确性。本节讨论推测性解码的无损保证,将这些保证分解为三个关键领域:
理论上的无损性 - 推测性解码采样在硬件数值精度限制内理论上是无损的。浮点误差可能会导致输出分布出现轻微变化,如 Accelerating Large Language Model Decoding with Speculative Sampling 中所讨论的。
算法无损性 - vLLM 的推测解码实现经过算法验证,确保无损。关键验证测试包括:
拒绝采样器收敛性:确保 vLLM 的拒绝采样器生成的样本与目标分布一致。查看测试代码
贪婪采样等价性: 确认带有推测解码的贪婪采样与不带推测解码的贪婪采样相匹配。这验证了vLLM的推测解码框架,当与vLLM前向传递和vLLM拒绝采样器集成时,提供了无损保证。此目录 中的几乎所有测试都使用`此断言实现 <vllm-project/vllm>`_ 来验证此属性。
vLLM Logprob 稳定性 - vLLM 目前不能保证稳定的 token 对数概率(logprobs)。这可能导致在不同运行中对同一请求产生不同的输出。更多详情,请参见 常见问题解答 中的 vLLM 中同一提示的输出在不同运行中会有所不同吗? 部分。
结论
虽然 vLLM 努力确保在推测性解码中的无损性,但由于以下因素,使用和不使用推测性解码生成的输出可能会有所不同:
浮点精度: 硬件数值精度的差异可能导致输出分布出现细微差异。
批量大小和数值稳定性: 批量大小的变化可能导致对数概率和输出概率的变化,这可能是由于批处理操作中的非确定性行为或数值不稳定性引起的。
缓解策略
对于缓解策略,请参阅 常见问题解答 中的条目 vLLM 中的提示输出在不同运行中是否会有所不同?。