Source code for langchain_community.llms.aviary
import dataclasses
import os
from typing import Any, Dict, List, Mapping, Optional, Union, cast
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
# Timeout, in seconds, passed as ``timeout=`` to the HTTP requests made by
# ``get_models`` and ``get_completions``.
TIMEOUT = 60
@dataclasses.dataclass
class AviaryBackend:
    """Connection settings for an Aviary backend.

    Attributes:
        backend_url: The URL of the Aviary backend. Request URLs are built by
            plain string concatenation onto this value, so it should end with
            "/" (``from_env`` guarantees that).
        bearer: The value sent in the ``Authorization`` header, e.g.
            ``"Bearer <token>"``. May be an empty string.
        header: Request headers derived from ``bearer``; computed in
            ``__post_init__`` and not accepted by the constructor.
    """

    backend_url: str
    # Hidden from repr so the credential does not leak into logs/tracebacks.
    bearer: str = dataclasses.field(repr=False)
    # Derived from ``bearer``; excluded from __init__, repr and comparisons.
    header: Dict[str, str] = dataclasses.field(
        init=False, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        self.header = {"Authorization": self.bearer}

    @classmethod
    def from_env(cls) -> "AviaryBackend":
        """Build a backend from the ``AVIARY_URL`` / ``AVIARY_TOKEN`` env vars.

        ``AVIARY_TOKEN`` is optional; when it is unset or empty the bearer is
        the empty string. A trailing "/" is appended to the URL if missing.

        Returns:
            A new ``AviaryBackend``.

        Raises:
            ValueError: If ``AVIARY_URL`` is unset or empty.
        """
        aviary_url = os.getenv("AVIARY_URL")
        # Explicit check rather than ``assert`` so it still runs under -O.
        if not aviary_url:
            raise ValueError("AVIARY_URL must be set")
        aviary_token = os.getenv("AVIARY_TOKEN", "")
        bearer = f"Bearer {aviary_token}" if aviary_token else ""
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        return cls(aviary_url, bearer)
def get_models() -> List[str]:
    """List the models served by the Aviary backend configured in the env.

    Queries the backend's ``-/routes`` endpoint and converts each route key
    that contains ``"--"`` into a model name by stripping leading "/" and
    replacing ``"--"`` with "/" (e.g. ``"/org--model"`` -> ``"org/model"``).

    Returns:
        The model names, sorted.

    Raises:
        requests.HTTPError: If the backend answers with an error status.
        RuntimeError: If the response body is not valid JSON.
    """
    backend = AviaryBackend.from_env()
    request_url = backend.backend_url + "-/routes"
    response = requests.get(request_url, headers=backend.header, timeout=TIMEOUT)
    # Fail loudly on 4xx/5xx: an error payload would otherwise be read as a
    # routes mapping and show up later as a misleading "unsupported model".
    response.raise_for_status()
    try:
        routes = response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {request_url}. Text response: {response.text}"
        ) from e
    return sorted(
        route.lstrip("/").replace("--", "/") for route in routes if "--" in route
    )
def get_completions(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    version: str = "",
) -> Dict[str, Union[str, float, int]]:
    """Get a completion for ``prompt`` from an Aviary-hosted model.

    Args:
        model: Model name such as ``"org/model"``; every "/" is encoded as
            ``"--"`` in the request path.
        prompt: The prompt text to send.
        use_prompt_format: Forwarded unchanged in the JSON request body.
        version: Path segment inserted verbatim directly before ``query``,
            so include any trailing "/" yourself. Empty by default.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: If the backend answers with an error status.
        RuntimeError: If the response body is not valid JSON.
    """
    backend = AviaryBackend.from_env()
    url = backend.backend_url + model.replace("/", "--") + "/" + version + "query"
    response = requests.post(
        url,
        headers=backend.header,
        json={"prompt": prompt, "use_prompt_format": use_prompt_format},
        timeout=TIMEOUT,
    )
    # Surface HTTP errors here instead of returning an error payload that the
    # caller would then fail to index.
    response.raise_for_status()
    try:
        return response.json()
    except requests.JSONDecodeError as e:
        raise RuntimeError(
            f"Error decoding JSON from {url}. Text response: {response.text}"
        ) from e
class Aviary(LLM):
    """Aviary hosted models.

    Aviary is a backend for hosted models. You can learn more about Aviary at:
    http://github.com/ray-project/aviary

    To get a list of the models Aviary supports, install the Aviary CLI by
    following the instructions on that site, then run:
    `aviary models`

    AVIARY_URL must be provided, either as the ``aviary_url`` field or as an
    environment variable. AVIARY_TOKEN / ``aviary_token`` is optional; an
    empty token results in an empty ``Authorization`` header.

    Attributes:
        model: Name of the model to use. Defaults to "amazon/LightGPT".
        aviary_url: URL of the Aviary backend. Defaults to None, in which case
            the AVIARY_URL environment variable is used.
        aviary_token: Bearer token for the Aviary backend. Defaults to None,
            in which case the AVIARY_TOKEN environment variable is used.
        use_prompt_format: Sent to the backend as ``use_prompt_format``.
            Defaults to True.
        version: API version path segment to use for Aviary. Defaults to None.

    Example:
        .. code-block:: python

            import os

            from langchain_community.llms import Aviary

            os.environ["AVIARY_URL"] = "<URL>"
            os.environ["AVIARY_TOKEN"] = "<TOKEN>"
            light = Aviary(model='amazon/LightGPT')
            output = light('How do you make fried rice?')
    """

    model: str = "amazon/LightGPT"
    aviary_url: Optional[str] = None
    aviary_token: Optional[str] = None
    # Sent to the backend as ``use_prompt_format``.
    # NOTE(review): the original comment said True makes the model's prompt
    # template be *ignored*, which reads inverted relative to the flag name —
    # confirm against the Aviary API.
    use_prompt_format: bool = True
    # API version to use for Aviary (inserted into the request path).
    version: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the backend URL/token and check that the model is served.

        Raises:
            ValueError: If the backend cannot be reached, or if it does not
                serve the requested ``model``.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        # The token is optional, matching AviaryBackend.from_env, which maps
        # an empty token to an empty bearer.
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        # The module-level helpers read their configuration from these
        # environment variables.
        os.environ["AVIARY_URL"] = aviary_url
        os.environ["AVIARY_TOKEN"] = aviary_token
        try:
            aviary_models = get_models()
        except requests.exceptions.RequestException as e:
            raise ValueError(e) from e
        model = values.get("model")
        # Only an explicitly supplied model is checked against the backend.
        if model and model not in aviary_models:
            raise ValueError(f"{aviary_url} does not support model {model}.")
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model,
            "aviary_url": self.aviary_url,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return f"aviary-{self.model.replace('/', '-')}"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call Aviary.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences; when given, the generated text is
                post-processed with ``enforce_stop_tokens``.
            run_manager: Not used by this implementation.
            **kwargs: Ignored; only this instance's settings are sent.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # Distinct name so the caller's **kwargs are not silently shadowed.
        request_kwargs: Dict[str, Any] = {
            "use_prompt_format": self.use_prompt_format
        }
        if self.version:
            request_kwargs["version"] = self.version
        output = get_completions(
            model=self.model,
            prompt=prompt,
            **request_kwargs,
        )
        text = cast(str, output["generated_text"])
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text