"""EdenAI的Generation API的封装。"""
import logging
from typing import Any, Dict, List, Literal, Optional
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.utilities.requests import Requests
# Module-level logger; `EdenAI.build_extra` uses it to warn when an
# unrecognised keyword argument is moved into `model_kwargs`.
logger = logging.getLogger(__name__)
class EdenAI(LLM):
    """EdenAI models.

    To use, set the environment variable ``EDENAI_API_KEY`` to your API token.
    You can find your token here:
    https://app.edenai.run/admin/account/settings

    `feature` and `subfeature` are required, but any other model parameter can
    also be passed in the form ``params={model_param: value, ...}``.

    For the API reference, see the EdenAI documentation:
    http://docs.edenai.co.
    """

    # Root of the EdenAI REST API; the endpoint is
    # ``{base_url}/{feature}/{subfeature}``.
    base_url: str = "https://api.edenai.run/v2"

    # API token; filled in from ``EDENAI_API_KEY`` by `validate_environment`
    # when not passed explicitly.
    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use; text generation by default."""

    subfeature: Literal["generation"] = "generation"
    """Sub-feature of the feature above; generation by default."""

    provider: str
    """Generation provider to use (e.g. openai, stabilityai, cohere, google)."""

    model: Optional[str] = None
    """Model name for the provider above
    (e.g. 'gpt-3.5-turbo-instruct' for openai).
    Available models are listed at https://docs.edenai.co/ under
    'available providers'."""

    # Optional parameters to add depending on the chosen feature;
    # see the API reference for more information.
    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image

    params: Dict[str, Any] = Field(default_factory=dict)
    """DEPRECATED: use temperature, max_tokens, resolution directly.
    Optional parameters to pass to the API."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra parameters."""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the values or the environment."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Move keyword arguments that are not declared fields into model_kwargs.

        Raises:
            ValueError: if a name is given both as a keyword argument and
                inside ``model_kwargs``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because entries are popped
        # from `values` inside the loop.
        for field_name in list(values):
            if field_name in all_required_field_names:
                continue
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            logger.warning(
                f"{field_name} was transferred to model_kwargs.\n"
                f"Please confirm that {field_name} is what you intended."
            )
            extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _llm_type(self) -> str:
        """Return the type identifier of this LLM."""
        return "edenai"

    def _format_output(self, output: dict) -> str:
        """Pull the generated payload for ``self.provider`` out of the response.

        Returns the ``generated_text`` entry for the text feature, otherwise
        the ``image`` entry of the first returned item.
        """
        if self.feature == "text":
            return output[self.provider]["generated_text"]
        else:
            return output[self.provider]["items"][0]["image"]

    @staticmethod
    def get_user_agent() -> str:
        """Return the User-Agent header value sent with every request."""
        # Imported lazily, as in the original implementation.
        from langchain_community import __version__

        return f"langchain/{__version__}"

    def _resolve_stops(self, stop: Optional[List[str]]) -> Optional[List[str]]:
        """Pick the stop sequences to apply, rejecting ambiguous input.

        Raises:
            ValueError: if both ``self.stop_sequences`` and ``stop`` are set.
        """
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        if self.stop_sequences is not None:
            return self.stop_sequences
        return stop

    def _endpoint_url(self) -> str:
        """Build the endpoint URL from base_url, feature and subfeature."""
        return f"{self.base_url}/{self.feature}/{self.subfeature}"

    def _request_headers(self) -> Dict[str, str]:
        """Build the HTTP headers (bearer token and User-Agent)."""
        return {
            "Authorization": f"Bearer {self.edenai_api_key}",
            "User-Agent": self.get_user_agent(),
        }

    def _build_payload(self, prompt: str, extra: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the JSON payload shared by the sync and async calls.

        Later entries win: ``self.params`` overrides the dedicated fields,
        call-time ``extra`` overrides ``self.params``, and ``num_images`` is
        always forced to 1.
        """
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "text": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "resolution": self.resolution,
            **self.params,
            **extra,
            "num_images": 1,  # always limit to 1 (ignored for text)
        }
        # Filter `None` values so they are not sent to the API as null.
        payload = {k: v for k, v in payload.items() if v is not None}
        if self.model is not None:
            payload["settings"] = {self.provider: self.model}
        return payload

    @staticmethod
    def _raise_for_status(status: int, body: str) -> None:
        """Raise the appropriate error for a non-200 HTTP status.

        Raises:
            Exception: for 5xx statuses and for any other non-200 status
                below 400.
            ValueError: for 4xx statuses.
        """
        if status >= 500:
            raise Exception(f"EdenAI Server: Error {status}")
        elif status >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {body}")
        elif status != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{status}: {body}"
            )

    def _extract_output(self, data: dict, stops: Optional[List[str]]) -> str:
        """Validate the provider section of the response and format the output.

        Raises:
            Exception: if the provider entry reports ``status == "fail"``.
        """
        provider_response = data[self.provider]
        if provider_response.get("status") == "fail":
            # `or {}` guards against an explicit `"error": null` in the body.
            err_msg = (provider_response.get("error") or {}).get("message")
            raise Exception(err_msg)

        output = self._format_output(data)
        if stops is not None:
            output = enforce_stop_tokens(output, stops)
        return output

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to EdenAI's generation endpoint.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional list of stop words; must not be combined with
                ``self.stop_sequences``.
            run_manager: Callback manager for the run (not used here).
            **kwargs: Extra entries merged into the request payload.

        Returns:
            The generated output as a string, cut at the first stop sequence
            when stop sequences are in effect.

        Raises:
            ValueError: if stop sequences are given twice, or the API answers
                with a 4xx status.
            Exception: on 5xx or other non-200 statuses, or when the provider
                reports a failure.
        """
        stops = self._resolve_stops(stop)
        payload = self._build_payload(prompt, kwargs)

        request = Requests(headers=self._request_headers())
        response = request.post(url=self._endpoint_url(), data=payload)

        if response.status_code != 200:
            self._raise_for_status(response.status_code, response.text)

        return self._extract_output(response.json(), stops)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronously call the EdenAI model to get a prediction.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional list of stop words; must not be combined with
                ``self.stop_sequences``.
            run_manager: Async callback manager for the run (not used here).
            **kwargs: Extra entries merged into the request payload.

        Returns:
            The string generated by the model, cut at the first stop sequence
            when stop sequences are in effect.

        Raises:
            ValueError: if stop sequences are given twice, or the API answers
                with a 4xx status.
            Exception: on 5xx or other non-200 statuses, or when the provider
                reports a failure.
        """
        stops = self._resolve_stops(stop)
        payload = self._build_payload(prompt, kwargs)
        url = self._endpoint_url()
        headers = self._request_headers()

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                status = response.status
                if status != 200:
                    # `ClientResponse.text` is a coroutine method and must be
                    # awaited to obtain the body. The 5xx message does not
                    # include the body, so it is not read in that case.
                    body = "" if status >= 500 else await response.text()
                    self._raise_for_status(status, body)

                response_json = await response.json()
                return self._extract_output(response_json, stops)