Source code for langchain_community.chat_models.maritalk

import json
from http import HTTPStatus
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import requests
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field
from requests import Response
from requests.exceptions import HTTPError


class MaritalkHTTPError(HTTPError):
    def __init__(self, request_obj: Response) -> None:
        self.request_obj = request_obj
        try:
            response_json = request_obj.json()
            if "detail" in response_json:
                api_message = response_json["detail"]
            elif "message" in response_json:
                api_message = response_json["message"]
            else:
                api_message = response_json
        except Exception:
            api_message = request_obj.text

        self.message = api_message
        self.status_code = request_obj.status_code

    def __str__(self) -> str:
        status_code_meaning = HTTPStatus(self.status_code).phrase
        formatted_message = f"HTTP Error: {self.status_code} - {status_code_meaning}"
        formatted_message += f"\nDetail: {self.message}"
        return formatted_message
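
# Illustrative sketch (not part of the original module): formatting a failed
# response. Assumes `resp` is a `requests.Response` with status code 429 whose
# body carries a "detail" field of "rate limit exceeded".
#
#     >>> err = MaritalkHTTPError(resp)
#     >>> print(err)
#     HTTP Error: 429 - Too Many Requests
#     Detail: rate limit exceeded
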
class ChatMaritalk(BaseChatModel):
    """`MariTalk` Chat models API.

    This class allows interacting with the MariTalk chatbot API.
    To use it, the API key must be provided through the constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMaritalk
            chat = ChatMaritalk(
                model="sabia-2-medium", api_key="your_api_key_here"
            )
    """

    api_key: str
    """Your MariTalk API key."""

    model: str
    """Choose one of the available models:
    - `sabia-2-medium`
    - `sabia-2-small`
    - `sabia-2-medium-2024-03-13`
    - `sabia-2-small-2024-03-13`
    - `maritalk-2024-01-08` (deprecated)"""

    temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
    """Run inference with this temperature.
    Must be in the open interval (0.0, 1.0)."""

    max_tokens: int = Field(default=512, gt=0)
    """The maximum number of tokens to generate in the reply."""

    do_sample: bool = Field(default=True)
    """Whether or not to use sampling; use `True` to enable."""

    top_p: float = Field(default=0.95, gt=0.0, lt=1.0)
    """Nucleus sampling parameter controlling the size of
    the probability mass considered for sampling."""

    @property
    def _llm_type(self) -> str:
        """Identifies the LLM type as 'maritalk'."""
        return "maritalk"
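
    # Illustrative note (not part of the original module): pydantic enforces
    # the Field constraints above at construction time, so an out-of-range
    # sampling parameter fails fast, e.g.
    #
    #     >>> ChatMaritalk(model="sabia-2-small", api_key="...", temperature=1.5)
    #     pydantic.error_wrappers.ValidationError: ... ensure this value is
    #     less than 1.0 ...
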
    def parse_messages_for_model(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
        """Parses messages from LangChain's format into the format expected
        by the MariTalk API.

        Parameters:
            messages (List[BaseMessage]): A list of messages in LangChain
                format to be parsed.

        Returns:
            A list of messages formatted for the MariTalk API.
        """
        parsed_messages = []

        for message in messages:
            if isinstance(message, HumanMessage):
                role = "user"
            elif isinstance(message, AIMessage):
                role = "assistant"
            elif isinstance(message, SystemMessage):
                role = "system"
            else:
                # Guard against unsupported message types; without this branch
                # `role` would be unbound and raise a confusing NameError.
                raise ValueError(f"Unsupported message type: {type(message)}")

            parsed_messages.append({"role": role, "content": message.content})

        return parsed_messages
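
    # Illustrative example (not part of the original module): parsing a short
    # conversation yields role/content dicts in MariTalk's format, e.g.
    #
    #     >>> chat.parse_messages_for_model(
    #     ...     [SystemMessage(content="Seja conciso."),
    #     ...      HumanMessage(content="Oi!")]
    #     ... )
    #     [{'role': 'system', 'content': 'Seja conciso.'},
    #      {'role': 'user', 'content': 'Oi!'}]
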
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Sends the parsed messages to the MariTalk API and returns the
        generated response.

        This method makes an HTTP POST request to the MariTalk API with the
        provided messages and other parameters. If the request succeeds and
        the API returns a response, this method returns a string containing
        the answer. If the request is rate limited or encounters another
        error, a `MaritalkHTTPError` is raised.

        Parameters:
            messages (List[BaseMessage]): Messages to send to the model.
            stop (Optional[List[str]]): Tokens that will signal the model to
                stop generating further tokens.

        Returns:
            str: The answer if the API call succeeds.

        Raises:
            MaritalkHTTPError: If the API call fails (e.g., rate limiting).
        """
        url = "https://chat.maritaca.ai/api/chat/inference"
        headers = {"authorization": f"Key {self.api_key}"}
        stopping_tokens = stop if stop is not None else []

        parsed_messages = self.parse_messages_for_model(messages)

        data = {
            "messages": parsed_messages,
            "model": self.model,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stopping_tokens": stopping_tokens,
            **kwargs,
        }

        response = requests.post(url, json=data, headers=headers)

        if response.ok:
            return response.json().get("answer", "No answer found")
        else:
            raise MaritalkHTTPError(response)

    async def _acall(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously sends the parsed messages to the MariTalk API and
        returns the generated response.

        This method makes an HTTP POST request to the MariTalk API using
        asynchronous I/O with the provided messages and other parameters. If
        the request succeeds and the API returns a response, this method
        returns a string containing the answer. If the request is rate
        limited or encounters another error, a `MaritalkHTTPError` is raised.
        """
        try:
            import httpx

            url = "https://chat.maritaca.ai/api/chat/inference"
            headers = {"authorization": f"Key {self.api_key}"}
            stopping_tokens = stop if stop is not None else []

            parsed_messages = self.parse_messages_for_model(messages)

            data = {
                "messages": parsed_messages,
                "model": self.model,
                "do_sample": self.do_sample,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "stopping_tokens": stopping_tokens,
                **kwargs,
            }

            async with httpx.AsyncClient() as client:
                response = await client.post(
                    url, json=data, headers=headers, timeout=None
                )

                if response.status_code == 200:
                    return response.json().get("answer", "No answer found")
                else:
                    raise MaritalkHTTPError(response)

        except ImportError:
            raise ImportError(
                "Could not import httpx python package. "
                "Please install it with `pip install httpx`."
            )
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        headers = {"Authorization": f"Key {self.api_key}"}
        stopping_tokens = stop if stop is not None else []

        parsed_messages = self.parse_messages_for_model(messages)

        data = {
            "messages": parsed_messages,
            "model": self.model,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stopping_tokens": stopping_tokens,
            "stream": True,
            **kwargs,
        }

        response = requests.post(
            "https://chat.maritaca.ai/api/chat/inference",
            data=json.dumps(data),
            headers=headers,
            stream=True,
        )

        if response.ok:
            for line in response.iter_lines():
                if line.startswith(b"data: "):
                    response_data = line.replace(b"data: ", b"").decode("utf-8")
                    if response_data:
                        parsed_data = json.loads(response_data)
                        if "text" in parsed_data:
                            delta = parsed_data["text"]
                            chunk = ChatGenerationChunk(
                                message=AIMessageChunk(content=delta)
                            )
                            if run_manager:
                                run_manager.on_llm_new_token(delta, chunk=chunk)
                            yield chunk
        else:
            raise MaritalkHTTPError(response)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        try:
            import httpx

            headers = {"Authorization": f"Key {self.api_key}"}
            stopping_tokens = stop if stop is not None else []

            parsed_messages = self.parse_messages_for_model(messages)

            data = {
                "messages": parsed_messages,
                "model": self.model,
                "do_sample": self.do_sample,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "stopping_tokens": stopping_tokens,
                "stream": True,
                **kwargs,
            }

            async with httpx.AsyncClient() as client:
                async with client.stream(
                    "POST",
                    "https://chat.maritaca.ai/api/chat/inference",
                    data=json.dumps(data),
                    headers=headers,
                    timeout=None,
                ) as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            if line.startswith("data: "):
                                response_data = line.replace("data: ", "")
                                if response_data:
                                    parsed_data = json.loads(response_data)
                                    if "text" in parsed_data:
                                        delta = parsed_data["text"]
                                        chunk = ChatGenerationChunk(
                                            message=AIMessageChunk(content=delta)
                                        )
                                        if run_manager:
                                            await run_manager.on_llm_new_token(
                                                delta, chunk=chunk
                                            )
                                        yield chunk
                    else:
                        raise MaritalkHTTPError(response)

        except ImportError:
            raise ImportError(
                "Could not import httpx python package. "
                "Please install it with `pip install httpx`."
            )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = await self._acall(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifies the key parameters of the chat model for logging or
        tracking purposes.

        Returns:
            A dictionary of the key configuration parameters.
        """
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }
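
# Usage sketch (illustrative, not part of the original module; assumes a
# valid MariTalk API key):
#
#     from langchain_community.chat_models import ChatMaritalk
#     from langchain_core.messages import HumanMessage, SystemMessage
#
#     chat = ChatMaritalk(model="sabia-2-medium", api_key="your_api_key_here")
#     messages = [
#         SystemMessage(content="You are a helpful assistant."),
#         HumanMessage(content="Qual é a capital do Brasil?"),
#     ]
#
#     # Synchronous call; dispatches to _generate/_call under the hood.
#     print(chat.invoke(messages).content)
#
#     # Streaming; dispatches to _stream and yields AIMessageChunk contents.
#     for chunk in chat.stream(messages):
#         print(chunk.content, end="", flush=True)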