LoRA with Quantization Inference
Source: vllm-project/vllm.
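The models and LoRA adapters used below are pulled from the Hugging Face Hub, and the script's docstring notes that Hugging Face credentials are required. One way to supply them ahead of time (this snippet is an addition for illustration, not part of the original example) is:

# Log in to the Hugging Face Hub before running the example.
from huggingface_hub import login

login()  # prompts interactively for a token, or pass token="hf_..." explicitly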
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    return [
        # this is an example of using quantization without LoRA
        ("My name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        # the next three examples use quantization with LoRA
        ("my name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-1", 1, lora_path)),
        ("The capital of USA is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-2", 1, lora_path)),
        ("The capital of France is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-3", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        # Feed at most one new request per iteration, then advance the
        # engine by one step.
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        # One scheduling/decoding iteration; returns outputs for the
        # requests processed in this step.
        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
        # It quantizes the model at load time, using some config info from the
        # LoRA adapter repo, so load_format and qlora_adapter_name_or_path
        # must be set as below.
        engine_args = EngineArgs(model=model,
                                 quantization=quantization,
                                 qlora_adapter_name_or_path=lora_repo,
                                 load_format="bitsandbytes",
                                 enable_lora=True,
                                 max_lora_rank=64)
    else:
        engine_args = EngineArgs(model=model,
                                 quantization=quantization,
                                 enable_lora=True,
                                 max_loras=4)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

    test_configs = [{
        "name": "qlora_inference_example",
        'model': "huggyllama/llama-7b",
        'quantization': "bitsandbytes",
        'lora_repo': 'timdettmers/qlora-flan-7b'
    }, {
        "name": "AWQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
        'quantization': "awq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }, {
        "name": "GPTQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
        'quantization': "gptq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }]

    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
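The script drives LLMEngine directly so it can interleave request submission with engine.step() and print outputs as they finish. For a quick smoke test, the same LoRA-plus-quantization combination can also be run through the high-level LLM API. The sketch below is not part of the original example; it simply reuses the AWQ model and adapter repo from test_configs above:

# Minimal sketch: same AWQ model + LoRA adapter as above, but using the
# high-level LLM API instead of managing LLMEngine manually.
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

llm = LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
          quantization="awq",
          enable_lora=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64),
                       lora_request=LoRARequest("lora-test", 1, lora_path))
for output in outputs:
    print(output.outputs[0].text)

For the bitsandbytes (QLoRA) configuration, the extra arguments shown in initialize_engine (load_format="bitsandbytes" and qlora_adapter_name_or_path) would still be needed.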