Source code for langchain_community.llms.baseten

import logging
import os
from typing import Any, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field

logger = logging.getLogger(__name__)


class Baseten(LLM):
    """Baseten model

    This module allows using LLMs hosted on Baseten.

    The LLM deployed on Baseten must have the following properties:

    * Must accept input as a dictionary with the key "prompt"
    * May accept other input in the dictionary passed through with kwargs
    * Must return a string with the model output

    To use this module, you must:

    * Export your Baseten API key as the environment variable `BASETEN_API_KEY`
    * Get the model ID for your model from your Baseten dashboard
    * Identify the model deployment ("production" for all model library models)

    These code samples use
    [Mistral 7B Instruct](https://app.baseten.co/explore/mistral_7b_instruct)
    from Baseten's model library.

    Example:
        .. code-block:: python

            from langchain_community.llms import Baseten
            # Production deployment
            mistral = Baseten(model="MODEL_ID", deployment="production")
            mistral("What is the Mistral wind?")

        .. code-block:: python

            from langchain_community.llms import Baseten
            # Development deployment
            mistral = Baseten(model="MODEL_ID", deployment="development")
            mistral("What is the Mistral wind?")

        .. code-block:: python

            from langchain_community.llms import Baseten
            # Other published deployment
            mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
            mistral("What is the Mistral wind?")
    """

    model: str
    deployment: str
    input: Dict[str, Any] = Field(default_factory=dict)
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "baseten"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        baseten_api_key = os.environ["BASETEN_API_KEY"]
        model_id = self.model
        # Route to the predict endpoint for the chosen deployment.
        if self.deployment == "production":
            model_url = f"https://model-{model_id}.api.baseten.co/production/predict"
        elif self.deployment == "development":
            model_url = f"https://model-{model_id}.api.baseten.co/development/predict"
        else:  # try specific deployment ID
            model_url = (
                f"https://model-{model_id}.api.baseten.co/"
                f"deployment/{self.deployment}/predict"
            )
        # The prompt plus any extra kwargs are sent as the request body.
        response = requests.post(
            model_url,
            headers={"Authorization": f"Api-Key {baseten_api_key}"},
            json={"prompt": prompt, **kwargs},
        )
        return response.json()
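
# A minimal end-to-end usage sketch (not part of the module source). It assumes
# a valid model ID from your Baseten dashboard, a "production" deployment, and
# that BASETEN_API_KEY is set in the environment. The extra kwarg `max_length`
# is purely illustrative: extra kwargs are forwarded in the request body
# alongside "prompt", so they only take effect if the deployed model accepts
# them.
#
#     import os
#     from langchain_community.llms import Baseten
#
#     os.environ["BASETEN_API_KEY"] = "YOUR_API_KEY"  # placeholder key
#
#     mistral = Baseten(model="MODEL_ID", deployment="production")
#
#     # .invoke() passes kwargs through to _call, which posts them to the
#     # model's predict endpoint; the older mistral("...") call style from the
#     # docstring examples also works.
#     print(mistral.invoke("What is the Mistral wind?", max_length=128))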