from __future__ import annotations
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
class VolcEngineMaasBase(BaseModel):
"""用于VolcEngineMaas模型的基类。"""
client: Any
volc_engine_maas_ak: Optional[SecretStr] = None
"""用于Volc引擎的访问密钥"""
volc_engine_maas_sk: Optional[SecretStr] = None
"""用于 Volc 引擎的密钥"""
endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com"
"""VolcEngineMaas LLM的端点。"""
region: Optional[str] = "Region"
"""VolcEngineMaas LLM 的区域。"""
model: str = "skylark-lite-public"
"""模型名称。您可以在此处查看此模型的详细信息
https://www.volcengine.com/docs/82379/1133187
您可以通过更改此字段来选择其他模型"""
model_version: Optional[str] = None
"""模型版本。仅在Moonshot大型语言模型中使用。
您可以在此处查看详细信息 https://www.volcengine.com/docs/82379/1158281"""
top_p: Optional[float] = 0.8
"""每一步需要考虑的标记的总概率质量。"""
temperature: Optional[float] = 0.95
"""一个非负浮点数,用于调整生成过程中的随机程度。"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""模型特殊参数,您可以在模型页面上查看详细信息。"""
streaming: bool = False
"""是否流式传输结果。"""
connect_timeout: Optional[int] = 60
"""连接到volc引擎maas端点的超时时间。默认为60秒。"""
read_timeout: Optional[int] = 60
"""从volc引擎maas端点读取响应的超时时间。
默认值为60秒。"""
@root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the credentials and construct the MaasService client."""
volc_engine_maas_ak = convert_to_secret_str(
get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY")
)
volc_engine_maas_sk = convert_to_secret_str(
get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY")
)
        endpoint = values["endpoint"]
try:
from volcengine.maas import MaasService
maas = MaasService(
endpoint,
values["region"],
connection_timeout=values["connect_timeout"],
socket_timeout=values["read_timeout"],
)
maas.set_ak(volc_engine_maas_ak.get_secret_value())
maas.set_sk(volc_engine_maas_sk.get_secret_value())
values["volc_engine_maas_ak"] = volc_engine_maas_ak
values["volc_engine_maas_sk"] = volc_engine_maas_sk
values["client"] = maas
except ImportError:
raise ImportError(
"volcengine package not found, please install it with "
"`pip install volcengine`"
)
return values
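
    # Credential resolution sketch: either pass ``volc_engine_maas_ak`` /
    # ``volc_engine_maas_sk`` to the constructor, or export the corresponding
    # environment variables first (the key values below are placeholders, and
    # this assumes the ``volcengine`` package is installed):
    #
    #     import os
    #     os.environ["VOLC_ACCESSKEY"] = "your_ak"
    #     os.environ["VOLC_SECRETKEY"] = "your_sk"
    #     base = VolcEngineMaasBase()  # picks both up via get_from_dict_or_env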
@property
def _default_params(self) -> Dict[str, Any]:
"""获取调用 VolcEngineMaas API 的默认参数。"""
normal_params = {
"top_p": self.top_p,
"temperature": self.temperature,
}
return {**normal_params, **self.model_kwargs}
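
    # Parameter-merge sketch: entries in ``model_kwargs`` are merged over the
    # defaults above, so user-supplied values win (``max_new_tokens`` here is
    # illustrative, not a documented parameter; construction assumes
    # credentials are configured as in ``validate_environment``):
    #
    #     base = VolcEngineMaasBase(model_kwargs={"max_new_tokens": 256})
    #     base._default_params
    #     # -> {"top_p": 0.8, "temperature": 0.95, "max_new_tokens": 256}
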
class VolcEngineMaasLLM(LLM, VolcEngineMaasBase):
"""volc engine maas 主机拥有大量的模型。
您可以通过这个类来利用这些模型。
要使用,您应该已经安装了``volcengine`` python包。
并通过环境变量或直接传递给这个类来设置访问密钥和秘密密钥。
访问密钥、秘密密钥是必需的参数,您可以在以下链接获取帮助
https://www.volcengine.com/docs/6291/65568
为了使用它们,必须安装 'volcengine' Python 包。
访问密钥和秘密密钥必须通过环境变量或直接传递给这个类来设置。
访问密钥和秘密密钥是必填参数,可以在 https://www.volcengine.com/docs/6291/65568 寻求帮助。
示例:
.. code-block:: python
from langchain_community.llms import VolcEngineMaasLLM
model = VolcEngineMaasLLM(model="skylark-lite-public",
volc_engine_maas_ak="your_ak",
volc_engine_maas_sk="your_sk")
"""
@property
def _llm_type(self) -> str:
"""llm的返回类型。"""
return "volc-engine-maas-llm"
def _convert_prompt_msg_params(
self,
prompt: str,
**kwargs: Any,
    ) -> dict:
        """Assemble the model name, messages, and sampling parameters."""
model_req = {
"model": {
"name": self.model,
}
}
if self.model_version is not None:
model_req["model"]["version"] = self.model_version
return {
**model_req,
"messages": [{"role": "user", "content": prompt}],
"parameters": {**self._default_params, **kwargs},
}
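
    # Payload sketch: with the class defaults, ``_convert_prompt_msg_params("hi")``
    # produces a dict shaped roughly like this (a sketch, not the full wire
    # format accepted by the maas API):
    #
    #     {
    #         "model": {"name": "skylark-lite-public"},
    #         "messages": [{"role": "user", "content": "hi"}],
    #         "parameters": {"top_p": 0.8, "temperature": 0.95},
    #     }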
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
        if self.streaming:
            # Stream internally and concatenate the chunks into a single string.
            completion = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
params = self._convert_prompt_msg_params(prompt, **kwargs)
response = self.client.chat(params)
return response.get("choice", {}).get("message", {}).get("content", "")
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._convert_prompt_msg_params(prompt, **kwargs)
        for res in self.client.stream_chat(params):
            # Guard against empty frames that carry no content.
            if res:
chunk = GenerationChunk(
text=res.get("choice", {}).get("message", {}).get("content", "")
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
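
# Streaming usage sketch: because ``_stream`` is implemented, the standard
# Runnable ``.stream()`` API yields incremental text chunks (assumes
# credentials are set in the environment; the prompt text is illustrative):
#
#     llm = VolcEngineMaasLLM(model="skylark-lite-public")
#     for chunk in llm.stream("Tell me about the VolcEngine MaaS platform"):
#         print(chunk, end="", flush=True)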