from __future__ import annotations
from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.volcengine_maas import VolcEngineMaasBase
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into a ``{"role", "content"}`` request dict.

    The role is chosen by the first matching message class below (checked
    with ``isinstance``, so subclasses such as chunk types also match).

    Raises:
        ValueError: If ``message`` is not a system, human, AI or function
            message.
    """
    role_table = (
        (SystemMessage, "system"),
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (FunctionMessage, "function"),
    )
    for message_cls, role in role_table:
        if isinstance(message, message_cls):
            return {"role": role, "content": message.content}
    raise ValueError(f"Got unknown type {message}")
def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
    """Convert a chat response dict into an ``AIMessage``.

    The text is read from ``_dict["choice"]["message"]["content"]``.  A level
    that is missing *or* explicitly ``None`` is treated as empty, so the
    result is an ``AIMessage`` with empty-string content rather than an
    ``AttributeError`` from calling ``.get`` on ``None``.

    Args:
        _dict: Response mapping returned by the MaaS client.

    Returns:
        An ``AIMessage`` carrying the extracted content.
    """
    # ``or {}`` also covers keys that are present but set to None.
    choice = _dict.get("choice") or {}
    message = choice.get("message") or {}
    content = message.get("content")
    return AIMessage(content="" if content is None else content)
class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
    """Chat model for the models hosted on Volc Engine MaaS.

    To use it, the ``volcengine`` Python package must be installed, and an
    access key and a secret key must be supplied. Both keys are required; see
    https://www.volcengine.com/docs/6291/65568 for help obtaining them.

    The keys can be provided in either of two ways:

    * Environment variables: set ``VOLC_ACCESSKEY`` and ``VOLC_SECRETKEY`` to
      your access key and secret key.
    * Passing them directly to the class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import VolcEngineMaasChat
            model = VolcEngineMaasChat(model="skylark-lite-public",
                                       volc_engine_maas_ak="your_ak",
                                       volc_engine_maas_sk="your_sk")
    """

    @property
    def _llm_type(self) -> str:
        """Return the type identifier of this chat model."""
        return "volc-engine-maas-chat"

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether LangChain may serialize this model (it may not)."""
        return False

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this model instance.

        ``endpoint`` and ``model`` are listed first; keys from the parent
        implementation are merged afterwards and win on any clash.
        """
        return {
            "endpoint": self.endpoint,
            "model": self.model,
            **super()._identifying_params,
        }

    def _convert_prompt_msg_params(
        self,
        messages: List[BaseMessage],
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the request payload sent to the MaaS chat client.

        Args:
            messages: Conversation history to send.
            **kwargs: Extra generation parameters; each overrides the
                same-named key in ``self._default_params``.

        Returns:
            A dict with ``model`` (``name`` plus ``version`` when
            ``self.model_version`` is set), ``messages`` and ``parameters``.
        """
        model_info: Dict[str, Any] = {"name": self.model}
        if self.model_version is not None:
            model_info["version"] = self.model_version
        return {
            "model": model_info,
            "messages": [_convert_message_to_dict(m) for m in messages],
            "parameters": {**self._default_params, **kwargs},
        }

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield the reply piece by piece from ``client.stream_chat``.

        Falsy responses from the client are skipped. Each remaining piece is
        reported to ``run_manager`` (when given) before being yielded.
        """
        if stop is not None:
            kwargs["stop"] = stop
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.stream_chat(params):
            if not res:
                continue
            msg = convert_dict_to_message(res)
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
            if run_manager:
                run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a single complete reply.

        When ``self.streaming`` is set, the streamed pieces from ``_stream``
        are concatenated; otherwise one ``client.chat`` call is made.
        """
        if self.streaming:
            # Join once instead of repeated ``+=``, which is quadratic in
            # the number of streamed pieces on some interpreters.
            completion = "".join(
                chunk.text
                for chunk in self._stream(messages, stop, run_manager, **kwargs)
            )
        else:
            if stop is not None:
                kwargs["stop"] = stop
            params = self._convert_prompt_msg_params(messages, **kwargs)
            res = self.client.chat(params)
            msg = convert_dict_to_message(res)
            completion = cast(str, msg.content)
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])