# Source code for langchain_community.llms.baseten
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field
# Module-level logger, namespaced to this module per the logging convention.
logger = logging.getLogger(__name__)
class Baseten(LLM):
    """Baseten model.

    This module allows using LLMs hosted on Baseten.

    The LLM deployed on Baseten must have the following properties:

    * Must accept input as a dictionary with the key "prompt"
    * May accept other inputs passed through as a dictionary of kwargs
    * Must return a string containing the model output

    To use this module, you must:

    * Export your Baseten API key as the environment variable
      ``BASETEN_API_KEY``
    * Get your model's model ID from your Baseten dashboard
    * Identify the model deployment ("production" for all model library
      models)

    These code examples use the
    `Mistral 7B Instruct <https://app.baseten.co/explore/mistral_7b_instruct>`_
    model from Baseten's model library.

    Example:
        .. code-block:: python

            from langchain_community.llms import Baseten
            # Production deployment
            mistral = Baseten(model="MODEL_ID", deployment="production")
            mistral("What is the Mistral wind?")

        .. code-block:: python

            from langchain_community.llms import Baseten
            # Development deployment
            mistral = Baseten(model="MODEL_ID", deployment="development")
            mistral("What is the Mistral wind?")

        .. code-block:: python

            from langchain_community.llms import Baseten
            # Other published deployment
            mistral = Baseten(model="MODEL_ID", deployment="DEPLOYMENT_ID")
            mistral("What is the Mistral wind?")
    """

    # Baseten model ID, as shown in the Baseten dashboard.
    model: str
    # "production", "development", or a specific deployment ID.
    deployment: str
    # NOTE(review): declared but not read anywhere in this class —
    # presumably reserved for default request inputs; verify against callers.
    input: Dict[str, Any] = Field(default_factory=dict)
    # Default keyword arguments merged into every request payload;
    # call-time kwargs take precedence on key collisions.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # Include model and deployment so instances pointing at different
        # endpoints are distinguishable (e.g. for caching).
        return {
            "model": self.model,
            "deployment": self.deployment,
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "baseten"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call the deployed Baseten model and return its output.

        Args:
            prompt: The prompt to send, posted under the ``"prompt"`` key.
            stop: Stop sequences (accepted for interface compatibility;
                not forwarded to the model by this implementation).
            run_manager: Callback manager (unused here).
            **kwargs: Extra model inputs merged into the request payload,
                overriding any matching keys in ``model_kwargs``.

        Returns:
            The JSON-decoded response body (a string, per the deployment
            contract documented on the class).

        Raises:
            KeyError: If ``BASETEN_API_KEY`` is not set in the environment.
            requests.HTTPError: If the Baseten API returns an error status.
        """
        baseten_api_key = os.environ["BASETEN_API_KEY"]
        model_id = self.model
        # Map the deployment selector onto the corresponding endpoint URL.
        if self.deployment == "production":
            model_url = f"https://model-{model_id}.api.baseten.co/production/predict"
        elif self.deployment == "development":
            model_url = f"https://model-{model_id}.api.baseten.co/development/predict"
        else:  # try specific deployment ID
            model_url = (
                f"https://model-{model_id}.api.baseten.co"
                f"/deployment/{self.deployment}/predict"
            )
        response = requests.post(
            model_url,
            headers={"Authorization": f"Api-Key {baseten_api_key}"},
            # Fix: model_kwargs was declared but never sent; merge it in,
            # letting call-time kwargs win on conflicts.
            json={"prompt": prompt, **self.model_kwargs, **kwargs},
        )
        # Fix: surface HTTP failures instead of silently returning the
        # error payload (which would not be the declared str).
        response.raise_for_status()
        return response.json()