Offline Inference with MLPSpeculator

Source: vllm-project/vllm
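
This example runs the same offline generation workload twice, first with the base model alone and then with speculative decoding through an MLPSpeculator draft model, and prints the per-token latency of each run. MLPSpeculator proposes several candidate tokens per decoding step using small MLP heads conditioned on the base model's hidden state; the base model then verifies the proposals, so quality matches ordinary decoding while accepted tokens reduce the number of sequential forward passes.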

import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warm up first: run generation twice so that one-time initialization
    # costs do not skew the timing below.
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
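    # Time a full generation pass and report the average number of seconds
    # per generated token across all outputs.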
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print((end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")


if __name__ == "__main__":

    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n")

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)
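    # (temperature=0.0 means greedy decoding; because vLLM verifies draft
    # tokens against the base model, both runs below should produce the
    # same text.)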

    # Create an LLM without spec decoding
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    print("Without speculation")
    time_generation(llm, prompts, sampling_params)
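
    # Release the baseline engine so that its GPU memory can be reclaimed
    # before the speculative engine is created.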
    del llm
    gc.collect()

    # Create an LLM with spec decoding
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_model="ibm-fms/llama-13b-accelerator",
        # This is currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
    )

    print("With speculation")
    time_generation(llm, prompts, sampling_params)
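
The helper above prints the per-token latency but discards it. As a small extension (a sketch, not part of the upstream example; time_generation_v2 is a hypothetical name), the measurement could be returned so the script can report the speedup from speculation directly:

def time_generation_v2(llm: LLM, prompts: List[str],
                       sampling_params: SamplingParams) -> float:
    # Same warmup-then-measure pattern as time_generation, but the
    # seconds-per-token figure is returned instead of printed.
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    return (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs)

# Hypothetical usage, mirroring the script above (measure the baseline
# before the first engine is deleted):
#     baseline = time_generation_v2(llm, prompts, sampling_params)
#     ... del llm, gc.collect(), create the speculative engine ...
#     speculative = time_generation_v2(llm, prompts, sampling_params)
#     print(f"Speedup: {baseline / speculative:.2f}x")

When the speculator's proposals are accepted often, the second figure should be noticeably smaller; with a draft model that predicts poorly, verification overhead can erase the gain.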