Offline Inference with Audio Language Models
Source: vllm-project/vllm.
1"""
2This example shows how to use vLLM for running offline inference
3with the correct prompt format on audio language models.
4
5For most models, the prompt format should follow corresponding examples
6on HuggingFace model repository.
7"""
8from transformers import AutoTokenizer
9
10from vllm import LLM, SamplingParams
11from vllm.assets.audio import AudioAsset
12from vllm.utils import FlexibleArgumentParser
13
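# Two example audio clips provided through vLLM's built-in AudioAsset helper;
# the question list is indexed by how many clips are attached to the prompt
# (one question for the single-clip case, one for the two-clip case).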
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
question_per_audio_count = [
    "What is recited in the audio?",
    "What sport and what nursery rhyme are referenced?"
]


# Ultravox 0.3
def run_ultravox(question, audio_count):
    model_name = "fixie-ai/ultravox-v0_3"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
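    # Ultravox expects one audio placeholder token per audio clip, inserted
    # directly into the user message; the question text follows the
    # placeholders.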
    messages = [{
        'role': 'user',
        'content': "<|reserved_special_token_0|>\n" * audio_count + question
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

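    # enforce_eager disables CUDA graph capture, and limit_mm_per_prompt caps
    # how many audio items a single prompt may carry; it must cover the number
    # of placeholder tokens used above.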
    llm = LLM(model=model_name,
              enforce_eager=True,
              enable_chunked_prefill=False,
              max_model_len=8192,
              limit_mm_per_prompt={"audio": audio_count})
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "ultravox": run_ultravox,
}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    audio_count = args.num_audios
    llm, prompt, stop_token_ids = model_example_map[model](
        question_per_audio_count[audio_count - 1], audio_count)

    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
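    # Each multi-modal audio item is an (audio array, sampling rate) tuple,
    # exposed by AudioAsset.audio_and_sample_rate; the first audio_count
    # assets are attached to the prompt.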
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": [
                asset.audio_and_sample_rate
                for asset in audio_assets[:audio_count]
            ]
        },
    }
    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument("--num-audios",
                        type=int,
                        default=1,
                        choices=[1, 2],
                        help="Number of audio items per prompt.")

    args = parser.parse_args()
    main(args)
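For reference, the same flow can be written as a single linear script for the two-audio case, without the CLI scaffolding. The sketch below only rearranges calls that already appear above (the model name, placeholder token, LLM arguments, and generate() inputs are all taken from the example), so treat it as a condensed illustration rather than an additional API.

# Condensed two-audio sketch, assembled from the example above.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
question = "What sport and what nursery rhyme are referenced?"
# One placeholder token per audio clip, followed by the question.
messages = [{
    'role': 'user',
    'content': "<|reserved_special_token_0|>\n" * 2 + question
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name,
          enforce_eager=True,
          enable_chunked_prefill=False,
          max_model_len=8192,
          limit_mm_per_prompt={"audio": 2})

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            # Each item is an (audio array, sampling rate) tuple.
            "audio": [
                AudioAsset("mary_had_lamb").audio_and_sample_rate,
                AudioAsset("winning_call").audio_and_sample_rate,
            ]
        },
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64))
print(outputs[0].outputs[0].text)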