带有工具的 OpenAI 聊天完成客户端

带有工具的 OpenAI 聊天完成客户端

源代码 vllm-project/vllm

  1"""
  2Set up this example by starting a vLLM OpenAI-compatible server with tool call
  3options enabled. For example:
  4
  5IMPORTANT: for mistral, you must use one of the provided mistral tool call
  6templates, or your own - the model default doesn't work for tool calls with vLLM
  7See the vLLM docs on OpenAI server & tool calling for more details.
  8
  9vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \
 10            --chat-template examples/tool_chat_template_mistral.jinja \
 11            --enable-auto-tool-choice --tool-call-parser mistral
 12
 13OR
 14vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \
 15            --chat-template examples/tool_chat_template_hermes.jinja \
 16            --enable-auto-tool-choice --tool-call-parser hermes
 17"""
 18import json
 19
 20from openai import OpenAI
 21
 22# Modify OpenAI's API key and API base to use vLLM's API server.
 23openai_api_key = "EMPTY"
 24openai_api_base = "http://localhost:8000/v1"
 25
 26client = OpenAI(
 27    # defaults to os.environ.get("OPENAI_API_KEY")
 28    api_key=openai_api_key,
 29    base_url=openai_api_base,
 30)
 31
 32models = client.models.list()
 33model = models.data[0].id
 34
 35tools = [{
 36    "type": "function",
 37    "function": {
 38        "name": "get_current_weather",
 39        "description": "Get the current weather in a given location",
 40        "parameters": {
 41            "type": "object",
 42            "properties": {
 43                "city": {
 44                    "type":
 45                    "string",
 46                    "description":
 47                    "The city to find the weather for, e.g. 'San Francisco'"
 48                },
 49                "state": {
 50                    "type":
 51                    "string",
 52                    "description":
 53                    "the two-letter abbreviation for the state that the city is"
 54                    " in, e.g. 'CA' which would mean 'California'"
 55                },
 56                "unit": {
 57                    "type": "string",
 58                    "description": "The unit to fetch the temperature in",
 59                    "enum": ["celsius", "fahrenheit"]
 60                }
 61            },
 62            "required": ["city", "state", "unit"]
 63        }
 64    }
 65}]
 66
 67messages = [{
 68    "role": "user",
 69    "content": "Hi! How are you doing today?"
 70}, {
 71    "role": "assistant",
 72    "content": "I'm doing well! How can I help you?"
 73}, {
 74    "role":
 75    "user",
 76    "content":
 77    "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
 78}]
 79
 80chat_completion = client.chat.completions.create(messages=messages,
 81                                                 model=model,
 82                                                 tools=tools)
 83
 84print("Chat completion results:")
 85print(chat_completion)
 86print("\n\n")
 87
 88tool_calls_stream = client.chat.completions.create(messages=messages,
 89                                                   model=model,
 90                                                   tools=tools,
 91                                                   stream=True)
 92
 93chunks = []
 94for chunk in tool_calls_stream:
 95    chunks.append(chunk)
 96    if chunk.choices[0].delta.tool_calls:
 97        print(chunk.choices[0].delta.tool_calls[0])
 98    else:
 99        print(chunk.choices[0].delta)
100
# Reassemble the streamed tool-call fragments: concatenate the argument
# pieces belonging to each tool-call index into one JSON string apiece.
arguments = []
tool_call_idx = -1
for chunk in chunks:
    delta = chunk.choices[0].delta
    if not delta.tool_calls:
        continue
    fragment = delta.tool_calls[0]

    # A new index marks the start of the next tool call; flush the previous
    # (now complete) argument string before starting a fresh accumulator.
    if fragment.index != tool_call_idx:
        if tool_call_idx >= 0:
            print(
                f"streamed tool call arguments: {arguments[tool_call_idx]}"
            )
        tool_call_idx = fragment.index
        arguments.append("")

    # The id and function name each arrive once, on an early fragment.
    if fragment.id:
        print(f"streamed tool call id: {fragment.id} ")

    if fragment.function:
        if fragment.function.name:
            print(f"streamed tool call name: {fragment.function.name}")
        if fragment.function.arguments:
            arguments[tool_call_idx] += fragment.function.arguments

# The last tool call is never flushed inside the loop; print it here.
if arguments:
    print(f"streamed tool call arguments: {arguments[-1]}")

print("\n\n")

# Record the assistant's tool-call turn (from the non-streamed response)
# so the tool results appended next have a preceding assistant message.
messages.append({
    "role": "assistant",
    "tool_calls": chat_completion.choices[0].message.tool_calls,
})
134
135
# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: str) -> str:
    """Stand-in for a real weather lookup.

    The parameters mirror the tool schema but are intentionally ignored:
    a canned Dallas forecast is returned so the example runs without any
    external weather service.
    """
    # Fixed: the 'unit' annotation was the string literal 'str' instead of
    # the type, and "cloudly" was a typo for "cloudy".
    return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
            "partly cloudy, with highs in the 90's.")
140
141
# Dispatch table: tool name -> local implementation.
available_tools = {"get_current_weather": get_current_weather}

# Execute each tool call the model requested and append the result as a
# "tool" message tied back to the originating call id.
completion_tool_calls = chat_completion.choices[0].message.tool_calls
for call in completion_tool_calls:
    handler = available_tools[call.function.name]
    # Arguments arrive as a JSON object string matching the tool schema.
    result = handler(**json.loads(call.function.arguments))
    print(result)
    messages.append({
        "role": "tool",
        "content": result,
        "tool_call_id": call.id,
        "name": call.function.name,
    })
156
# Ask again with the tool results in context; the model should now answer
# the weather question in natural language rather than calling a tool.
chat_completion_2 = client.chat.completions.create(
    messages=messages,
    model=model,
    tools=tools,
    stream=False,
)
print("\n\n")
print(chat_completion_2)