LoRA with Quantization Inference

Source: vllm-project/vllm

"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    return [
        # this is an example of using quantization without LoRA
        ("My name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        # the next three examples use quantization with LoRA
        ("my name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-1", 1, lora_path)),
        ("The capital of USA is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-2", 1, lora_path)),
        ("The capital of France is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-3", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
        # It quantizes the model at load time, using some config info from the
        # LoRA adapter repo, so both load_format and
        # qlora_adapter_name_or_path must be set as below.
        engine_args = EngineArgs(model=model,
                                 quantization=quantization,
                                 qlora_adapter_name_or_path=lora_repo,
                                 load_format="bitsandbytes",
                                 enable_lora=True,
                                 max_lora_rank=64)
    else:
        engine_args = EngineArgs(model=model,
                                 quantization=quantization,
                                 enable_lora=True,
                                 max_loras=4)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

    test_configs = [{
        "name": "qlora_inference_example",
        "model": "huggyllama/llama-7b",
        "quantization": "bitsandbytes",
        "lora_repo": "timdettmers/qlora-flan-7b"
    }, {
        "name": "AWQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
        "quantization": "awq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }, {
        "name": "GPTQ_inference_with_lora_example",
        "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
        "quantization": "gptq",
        "lora_repo": "jashing/tinyllama-colorist-lora"
    }]

    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
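
The script above drives LLMEngine directly so that requests with and without a LoRA adapter can be interleaved and processed step by step. For a quick single-prompt check, the same quantization-plus-LoRA setup can also be expressed through the higher-level LLM API. The snippet below is a minimal sketch, reusing the AWQ model and adapter repos from the example above; exact keyword arguments may vary between vLLM versions.

from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Download the LoRA adapter used in the AWQ example above.
lora_path = snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

# Load the AWQ-quantized base model with LoRA support enabled.
llm = LLM(model="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
          quantization="awq",
          enable_lora=True)

# Apply the adapter at generation time by passing a LoRARequest.
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=128),
                       lora_request=LoRARequest("lora-test", 1, lora_path))

for output in outputs:
    print(f"Output: {output.outputs[0].text}")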