Source code for langchain_community.llms.volcengine_maas

from __future__ import annotations

from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


[docs]class VolcEngineMaasBase(BaseModel): """用于VolcEngineMaas模型的基类。""" client: Any volc_engine_maas_ak: Optional[SecretStr] = None """用于Volc引擎的访问密钥""" volc_engine_maas_sk: Optional[SecretStr] = None """用于 Volc 引擎的密钥""" endpoint: Optional[str] = "maas-api.ml-platform-cn-beijing.volces.com" """VolcEngineMaas LLM的端点。""" region: Optional[str] = "Region" """VolcEngineMaas LLM 的区域。""" model: str = "skylark-lite-public" """模型名称。您可以在此处查看此模型的详细信息 https://www.volcengine.com/docs/82379/1133187 您可以通过更改此字段来选择其他模型""" model_version: Optional[str] = None """模型版本。仅在Moonshot大型语言模型中使用。 您可以在此处查看详细信息 https://www.volcengine.com/docs/82379/1158281""" top_p: Optional[float] = 0.8 """每一步需要考虑的标记的总概率质量。""" temperature: Optional[float] = 0.95 """一个非负浮点数,用于调整生成过程中的随机程度。""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """模型特殊参数,您可以在模型页面上查看详细信息。""" streaming: bool = False """是否流式传输结果。""" connect_timeout: Optional[int] = 60 """连接到volc引擎maas端点的超时时间。默认为60秒。""" read_timeout: Optional[int] = 60 """从volc引擎maas端点读取响应的超时时间。 默认值为60秒。""" @root_validator() def validate_environment(cls, values: Dict) -> Dict: volc_engine_maas_ak = convert_to_secret_str( get_from_dict_or_env(values, "volc_engine_maas_ak", "VOLC_ACCESSKEY") ) volc_engine_maas_sk = convert_to_secret_str( get_from_dict_or_env(values, "volc_engine_maas_sk", "VOLC_SECRETKEY") ) endpoint = values["endpoint"] if values["endpoint"] is not None and values["endpoint"] != "": endpoint = values["endpoint"] try: from volcengine.maas import MaasService maas = MaasService( endpoint, values["region"], connection_timeout=values["connect_timeout"], socket_timeout=values["read_timeout"], ) maas.set_ak(volc_engine_maas_ak.get_secret_value()) maas.set_sk(volc_engine_maas_sk.get_secret_value()) values["volc_engine_maas_ak"] = volc_engine_maas_ak values["volc_engine_maas_sk"] = volc_engine_maas_sk values["client"] = maas except ImportError: raise ImportError( "volcengine package not found, please install it with " "`pip install volcengine`" ) return values @property def _default_params(self) -> Dict[str, Any]: """获取调用 VolcEngineMaas API 的默认参数。""" normal_params = { "top_p": self.top_p, "temperature": self.temperature, } return {**normal_params, **self.model_kwargs}
[docs]class VolcEngineMaasLLM(LLM, VolcEngineMaasBase): """volc engine maas 主机拥有大量的模型。 您可以通过这个类来利用这些模型。 要使用,您应该已经安装了``volcengine`` python包。 并通过环境变量或直接传递给这个类来设置访问密钥和秘密密钥。 访问密钥、秘密密钥是必需的参数,您可以在以下链接获取帮助 https://www.volcengine.com/docs/6291/65568 为了使用它们,必须安装 'volcengine' Python 包。 访问密钥和秘密密钥必须通过环境变量或直接传递给这个类来设置。 访问密钥和秘密密钥是必填参数,可以在 https://www.volcengine.com/docs/6291/65568 寻求帮助。 示例: .. code-block:: python from langchain_community.llms import VolcEngineMaasLLM model = VolcEngineMaasLLM(model="skylark-lite-public", volc_engine_maas_ak="your_ak", volc_engine_maas_sk="your_sk") """ @property def _llm_type(self) -> str: """llm的返回类型。""" return "volc-engine-maas-llm" def _convert_prompt_msg_params( self, prompt: str, **kwargs: Any, ) -> dict: model_req = { "model": { "name": self.model, } } if self.model_version is not None: model_req["model"]["version"] = self.model_version return { **model_req, "messages": [{"role": "user", "content": prompt}], "parameters": {**self._default_params, **kwargs}, } def _call( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: if self.streaming: completion = "" for chunk in self._stream(prompt, stop, run_manager, **kwargs): completion += chunk.text return completion params = self._convert_prompt_msg_params(prompt, **kwargs) response = self.client.chat(params) return response.get("choice", {}).get("message", {}).get("content", "") def _stream( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: params = self._convert_prompt_msg_params(prompt, **kwargs) for res in self.client.stream_chat(params): if res: chunk = GenerationChunk( text=res.get("choice", {}).get("message", {}).get("content", "") ) if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) yield chunk