import json
from http import HTTPStatus
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field
from requests import Response
from requests.exceptions import HTTPError
class MaritalkHTTPError(HTTPError):
    """HTTP error raised when the MariTalk API returns a non-success response.

    Extracts a human-readable message from the response body — the ``detail``
    or ``message`` key when the body is JSON, the raw text otherwise — and
    exposes both the message and the status code to callers.
    """

    def __init__(self, request_obj: Response) -> None:
        self.request_obj = request_obj
        try:
            response_json = request_obj.json()
            if "detail" in response_json:
                api_message = response_json["detail"]
            elif "message" in response_json:
                api_message = response_json["message"]
            else:
                api_message = response_json
        except Exception:
            # Body was not valid JSON; fall back to the raw response text.
            api_message = request_obj.text

        self.message = api_message
        self.status_code = request_obj.status_code

    def __str__(self) -> str:
        # HTTPStatus(...) raises ValueError for non-standard status codes
        # (e.g. 599 or CDN codes like 520); fall back to a generic phrase
        # rather than masking the original API error with a ValueError.
        try:
            status_code_meaning = HTTPStatus(self.status_code).phrase
        except ValueError:
            status_code_meaning = "Unknown Status Code"
        formatted_message = f"HTTP Error: {self.status_code} - {status_code_meaning}"
        formatted_message += f"\nDetail: {self.message}"
        return formatted_message
class ChatMaritalk(BaseChatModel):
    """`MariTalk` chat model API.

    This class allows interacting with the MariTalk chatbot API.
    To use it, an API key must be provided through the constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMaritalk

            chat = ChatMaritalk(api_key="your_api_key_here")
    """

    api_key: str
    """Your MariTalk API key."""

    model: str
    """Choose one of the available models:
    - `sabia-2-medium`
    - `sabia-2-small`
    - `sabia-2-medium-2024-03-13`
    - `sabia-2-small-2024-03-13`
    - `maritalk-2024-01-08` (deprecated)"""

    temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
    """Run inference with this temperature.
    Must be in the open interval (0.0, 1.0)."""

    max_tokens: int = Field(default=512, gt=0)
    """The maximum number of tokens to generate in the reply."""

    do_sample: bool = Field(default=True)
    """Whether to use sampling; use `True` to enable."""

    top_p: float = Field(default=0.95, gt=0.0, lt=1.0)
    """Nucleus sampling parameter controlling the size of the probability
    mass considered for sampling."""

    @property
    def _llm_type(self) -> str:
        """Identify the LLM type as 'maritalk'."""
        return "maritalk"

    def parse_messages_for_model(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
        """Parse messages from LangChain's format to the format expected by
        the MariTalk API.

        Parameters:
            messages (List[BaseMessage]): A list of messages in LangChain
                format to be parsed.

        Returns:
            A list of messages formatted for the MariTalk API.

        Raises:
            ValueError: If a message is of an unsupported type.
        """
        parsed_messages = []
        for message in messages:
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"
            else:
                # Previously an unsupported type fell through and crashed
                # with an opaque NameError on `role`; fail clearly instead.
                raise ValueError(
                    f"Got unsupported message type: {type(message).__name__}"
                )
            parsed_messages.append({"role": role, "content": message.content})
        return parsed_messages

    def _build_payload(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        stream: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Assemble the JSON request body shared by all API entry points.

        Explicit kwargs passed by the caller are merged last so they can
        override any of the defaults (including ``stream``), matching the
        original per-method dict-literal behavior.
        """
        payload: Dict[str, Any] = {
            "messages": self.parse_messages_for_model(messages),
            "model": self.model,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stopping_tokens": stop if stop is not None else [],
        }
        if stream:
            payload["stream"] = True
        payload.update(kwargs)
        return payload

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send the parsed messages to the MariTalk API and return the
        generated response.

        Makes an HTTP POST request to the MariTalk API with the provided
        messages and other parameters.

        Parameters:
            messages (List[BaseMessage]): Messages to send to the model.
            stop (Optional[List[str]]): Tokens that will trigger the model
                to stop generating further tokens.

        Returns:
            str: The answer from the API call.

        Raises:
            MaritalkHTTPError: If the API returns a non-success status code
                (e.g. when rate-limited).
        """
        url = "https://chat.maritaca.ai/api/chat/inference"
        headers = {"authorization": f"Key {self.api_key}"}
        data = self._build_payload(messages, stop, **kwargs)

        response = requests.post(url, json=data, headers=headers)
        if response.ok:
            return response.json().get("answer", "No answer found")
        raise MaritalkHTTPError(response)

    async def _acall(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously send the parsed messages to the MariTalk API and
        return the generated response.

        Makes an HTTP POST request to the MariTalk API using async I/O with
        the provided messages and other parameters.

        Returns:
            str: The answer from the API call.

        Raises:
            ImportError: If the optional ``httpx`` dependency is missing.
            MaritalkHTTPError: If the API returns a non-success status code.
        """
        # Keep the try narrow: only the import should be converted into the
        # install hint; ImportErrors raised later must propagate unchanged.
        try:
            import httpx
        except ImportError:
            raise ImportError(
                "Could not import httpx python package. "
                "Please install it with `pip install httpx`."
            )

        url = "https://chat.maritaca.ai/api/chat/inference"
        headers = {"authorization": f"Key {self.api_key}"}
        data = self._build_payload(messages, stop, **kwargs)

        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=data, headers=headers, timeout=None)

        if response.status_code == 200:
            return response.json().get("answer", "No answer found")
        raise MaritalkHTTPError(response)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the model response chunk by chunk over server-sent events.

        Yields:
            ChatGenerationChunk: One chunk per ``text`` delta in the stream.

        Raises:
            MaritalkHTTPError: If the API returns a non-success status code.
        """
        headers = {"Authorization": f"Key {self.api_key}"}
        data = self._build_payload(messages, stop, stream=True, **kwargs)

        response = requests.post(
            "https://chat.maritaca.ai/api/chat/inference",
            data=json.dumps(data),
            headers=headers,
            stream=True,
        )
        if not response.ok:
            raise MaritalkHTTPError(response)

        for line in response.iter_lines():
            # SSE frames arrive as b"data: {...}"; ignore keep-alives.
            if line.startswith(b"data: "):
                response_data = line.replace(b"data: ", b"").decode("utf-8")
                if response_data:
                    parsed_data = json.loads(response_data)
                    if "text" in parsed_data:
                        delta = parsed_data["text"]
                        chunk = ChatGenerationChunk(
                            message=AIMessageChunk(content=delta)
                        )
                        if run_manager:
                            run_manager.on_llm_new_token(delta, chunk=chunk)
                        yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream the model response chunk by chunk.

        Yields:
            ChatGenerationChunk: One chunk per ``text`` delta in the stream.

        Raises:
            ImportError: If the optional ``httpx`` dependency is missing.
            MaritalkHTTPError: If the API returns a non-success status code.
        """
        # Keep the try narrow: only the import should be converted into the
        # install hint; ImportErrors raised later must propagate unchanged.
        try:
            import httpx
        except ImportError:
            raise ImportError(
                "Could not import httpx python package. "
                "Please install it with `pip install httpx`."
            )

        headers = {"Authorization": f"Key {self.api_key}"}
        data = self._build_payload(messages, stop, stream=True, **kwargs)

        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                "https://chat.maritaca.ai/api/chat/inference",
                data=json.dumps(data),
                headers=headers,
                timeout=None,
            ) as response:
                if response.status_code != 200:
                    raise MaritalkHTTPError(response)

                async for line in response.aiter_lines():
                    # SSE frames arrive as "data: {...}"; ignore keep-alives.
                    if line.startswith("data: "):
                        response_data = line.replace("data: ", "")
                        if response_data:
                            parsed_data = json.loads(response_data)
                            if "text" in parsed_data:
                                delta = parsed_data["text"]
                                chunk = ChatGenerationChunk(
                                    message=AIMessageChunk(content=delta)
                                )
                                if run_manager:
                                    await run_manager.on_llm_new_token(
                                        delta, chunk=chunk
                                    )
                                yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a single ChatResult by delegating to :meth:`_call`."""
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a single ChatResult via :meth:`_acall`."""
        output_str = await self._acall(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identify the key parameters of the chat model for logging or
        tracing purposes.

        Returns:
            A dictionary of the key configuration parameters.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }