# Source code for langchain_community.llms.aviary

import dataclasses
import os
from typing import Any, Dict, List, Mapping, Optional, Union, cast

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

from langchain_community.llms.utils import enforce_stop_tokens

# Timeout (seconds) passed as ``timeout=`` to every ``requests`` call made to the
# Aviary backend by ``get_models`` and ``get_completions``.
TIMEOUT = 60


@dataclasses.dataclass
class AviaryBackend:
    """Aviary backend connection settings.

    Attributes:
        backend_url: URL of the Aviary backend.
        bearer: Value sent in the ``Authorization`` header (``from_env`` builds
            ``"Bearer <token>"``, or an empty string when no token is set).
    """

    backend_url: str
    bearer: str

    def __post_init__(self) -> None:
        # Header dict that the request helpers pass to ``requests``.
        self.header = {"Authorization": self.bearer}

    @classmethod
    def from_env(cls) -> "AviaryBackend":
        """Build a backend from the ``AVIARY_URL`` / ``AVIARY_TOKEN`` env vars.

        ``AVIARY_URL`` is required. A trailing ``/`` is appended when missing so
        route paths can be concatenated directly. ``AVIARY_TOKEN`` is optional;
        when unset or empty the bearer is the empty string.

        Returns:
            A new ``AviaryBackend``.

        Raises:
            ValueError: If ``AVIARY_URL`` is unset or empty.
        """
        aviary_url = os.getenv("AVIARY_URL")
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an empty URL through silently.
        if not aviary_url:
            raise ValueError("AVIARY_URL must be set")
        aviary_token = os.getenv("AVIARY_TOKEN", "")
        bearer = f"Bearer {aviary_token}" if aviary_token else ""
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        return cls(aviary_url, bearer)
def get_models() -> List[str]:
    """List the models served by the Aviary backend.

    The backend is resolved with ``AviaryBackend.from_env`` (configuration
    errors from there propagate). The route table is fetched from
    ``<backend>/-/routes``. Every route key containing ``--`` is treated as a
    model route and converted back to a model name: leading ``/`` characters
    are removed and ``--`` is replaced by ``/``.

    Returns:
        The model names, sorted.

    Raises:
        requests.HTTPError: If the backend answers with a 4xx/5xx status.
        RuntimeError: If the response body is not valid JSON.
    """
    backend = AviaryBackend.from_env()
    request_url = backend.backend_url + "-/routes"
    response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
    # Fail on HTTP errors (e.g. a rejected token) instead of treating an error
    # payload as a route table, which would yield a misleading model list.
    response.raise_for_status()
    try:
        routes = response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {request_url}. Text response: {response.text}"
        ) from e
    # Model routes encode the "/" of a model name as "--".
    return sorted(
        route.lstrip("/").replace("--", "/") for route in routes if "--" in route
    )
def get_completions(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    version: str = "",
) -> Dict[str, Union[str, float, int]]:
    """Get a completion from an Aviary model.

    POSTs ``{"prompt": prompt, "use_prompt_format": use_prompt_format}`` as
    JSON to ``<backend><model route>/<version>query``. The model route is the
    model name with ``/`` replaced by ``--``. The backend is resolved with
    ``AviaryBackend.from_env``.

    Args:
        model: Model name, e.g. ``"amazon/LightGPT"``.
        prompt: Prompt text to send.
        use_prompt_format: Forwarded unchanged in the request body.
        version: Inserted verbatim immediately before ``query`` in the URL.
            No separator is added, so a non-empty value presumably needs its
            own trailing ``/``.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: If the backend answers with a 4xx/5xx status.
        RuntimeError: If the response body is not valid JSON.
    """
    backend = AviaryBackend.from_env()
    url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
    response = requests.post(
        url,
        headers=backend.header,
        json={"prompt": prompt, "use_prompt_format": use_prompt_format},
        timeout=TIMEOUT,
    )
    # Surface HTTP errors here. Otherwise an error payload would be returned and
    # ``Aviary._call`` would fail later with a KeyError on "generated_text".
    response.raise_for_status()
    try:
        return response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {url}. Text response: {response.text}"
        ) from e
class Aviary(LLM):
    """Aviary hosted models.

    Aviary is a backend for hosted models. You can learn more about aviary at:
    http://github.com/ray-project/aviary

    To get a list of the models supported by aviary, follow the instructions on
    the website to install the aviary CLI and then use:
    `aviary models`

    The backend URL must be provided, either as ``aviary_url`` or through the
    ``AVIARY_URL`` environment variable. The token (``aviary_token`` /
    ``AVIARY_TOKEN``) is optional. When it is absent, requests are sent
    without a bearer token.

    Attributes:
        model: Name of the model to use. Defaults to "amazon/LightGPT".
        aviary_url: URL of the Aviary backend. Defaults to None, in which case
            ``AVIARY_URL`` is read from the environment.
        aviary_token: Bearer token for the Aviary backend. Defaults to None, in
            which case ``AVIARY_TOKEN`` is read from the environment.
        use_prompt_format: Forwarded to the backend as ``use_prompt_format``.
            Defaults to True.
        version: API version to use for Aviary. Defaults to None.

    Example:
        .. code-block:: python

            from langchain_community.llms import Aviary
            os.environ["AVIARY_URL"] = "<URL>"
            os.environ["AVIARY_TOKEN"] = "<TOKEN>"
            light = Aviary(model='amazon/LightGPT')
            output = light('How do you make fried rice?')
    """

    model: str = "amazon/LightGPT"
    aviary_url: Optional[str] = None
    aviary_token: Optional[str] = None
    # Sent to the backend as ``use_prompt_format``.
    # NOTE(review): the previous comment said "If True the prompt template for
    # the model will be ignored", which reads inverted relative to the field
    # name -- confirm against the Aviary backend.
    use_prompt_format: bool = True
    # API version to use for Aviary; forwarded to the query URL when set.
    version: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve backend settings and check that the model is served.

        Reads the URL and token from ``values`` or the environment. Exports
        them as ``AVIARY_URL`` / ``AVIARY_TOKEN`` for the module-level request
        helpers, then fetches the served model list.

        Raises:
            ValueError: If no URL is configured, the backend request fails, or
                the backend does not serve ``model``.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        # The token is optional, matching ``AviaryBackend.from_env``, which
        # sends an empty Authorization value when AVIARY_TOKEN is empty.
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        # Keep the resolved URL so ``_identifying_params`` reports the backend
        # actually used, even when it came from the environment.
        values["aviary_url"] = aviary_url

        # Set env variables for the aviary sdk helpers, which read them back
        # through ``AviaryBackend.from_env``.
        os.environ["AVIARY_URL"] = aviary_url
        os.environ["AVIARY_TOKEN"] = aviary_token
        try:
            aviary_models = get_models()
        except requests.exceptions.RequestException as e:
            raise ValueError(e) from e
        model = values.get("model")
        if model and model not in aviary_models:
            raise ValueError(
                f"{aviary_url} does not support model {values['model']}."
            )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model,
            "aviary_url": self.aviary_url,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return f"aviary-{self.model.replace('/', '-')}"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call Aviary.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences. When given, the generated text is
                truncated with ``enforce_stop_tokens``.
            run_manager: Not used by this implementation.
            **kwargs: Not used; they are not forwarded to the backend.

        Returns:
            The string generated by the model (the response's
            ``generated_text`` field).

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # Separate name so the ``**kwargs`` parameter is not shadowed.
        request_kwargs: Dict[str, Any] = {
            "use_prompt_format": self.use_prompt_format
        }
        if self.version:
            request_kwargs["version"] = self.version
        output = get_completions(
            model=self.model,
            prompt=prompt,
            **request_kwargs,
        )
        text = cast(str, output["generated_text"])
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text