import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from aiohttp import ClientSession
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.utilities.requests import Requests
def _message_role(type: str) -> str:
role_mapping = {"ai": "assistant", "human": "user", "chat": "user"}
if type in role_mapping:
return role_mapping[type]
else:
raise ValueError(f"Unknown type: {type}")
def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Convert a LangChain conversation into EdenAI chat request fields.

    The last message is sent as the ``text`` prompt. Every earlier message is
    sent in ``previous_history``, except a system message in the first
    position, which is sent as ``chatbot_global_action`` instead.

    Args:
        messages: The conversation, oldest message first. Must not be empty.

    Returns:
        A dict with the keys ``text``, ``previous_history`` (a list of
        ``{"role": ..., "message": ...}`` dicts) and ``chatbot_global_action``
        (``None`` when the conversation has no leading system message).

    Raises:
        ValueError: If ``messages`` is empty; if a system message appears in
            the history (every message but the last) at any position other
            than the first; or if a history message type has no EdenAI role.
    """
    # Fail with a clear message instead of an opaque IndexError from [-1].
    if not messages:
        raise ValueError("At least one message is required to call EdenAI chat.")

    system = None
    formatted_messages = []
    text = messages[-1].content
    # Only the history is inspected; the final message is always the prompt.
    for i, message in enumerate(messages[:-1]):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            system = message.content
        else:
            formatted_messages.append(
                {
                    "role": _message_role(message.type),
                    "message": message.content,
                }
            )
    return {
        "text": text,
        "previous_history": formatted_messages,
        "chatbot_global_action": system,
    }
class ChatEdenAI(BaseChatModel):
    """`EdenAI` chat large language models.

    `EdenAI` is a versatile platform that gives you access to language models
    from many different providers, such as Google, OpenAI, Cohere and Mistral.

    To get started, set the environment variable ``EDENAI_API_KEY`` to your
    API key, or pass the key to the constructor as a named parameter.

    `EdenAI` also lets you choose among a variety of models, including models
    such as "gpt-4".

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEdenAI
            from langchain_core.messages import HumanMessage

            # Initialize `ChatEdenAI` with the desired configuration
            chat = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75)

            # Create a list of messages to interact with the model
            messages = [HumanMessage(content="hello")]

            # Invoke the model with the provided messages
            chat.invoke(messages)

    `EdenAI` goes beyond simple model invocation and offers advanced features:

    - **Multiple providers**: access a wide range of language models offered
      by various providers, so you can choose the one best suited to your
      use case.
    - **Fallback mechanism**: configure a fallback so that, if the primary
      provider is unavailable, requests can switch to an alternative provider.
    - **Usage statistics**: usage statistics per project and per API key,
      letting you monitor and manage resource consumption.
    - **Monitoring and observability**: `EdenAI` provides comprehensive
      monitoring and observability tools on its platform.

    Example of setting up a fallback mechanism:
        .. code-block:: python

            # Initialize `ChatEdenAI` with a fallback provider
            chat_with_fallback = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75,
                fallback_providers="google")

    You can find more details here:
    https://docs.edenai.co/reference/text_chat_create
    """

    provider: str = "openai"
    """Chat provider to use (e.g. openai, google, etc.)."""

    model: Optional[str] = None
    """Model name for the above provider (e.g. 'gpt-4' for OpenAI).
    Available models are listed on https://docs.edenai.co/ under
    'available providers'."""

    max_tokens: int = 256
    """Number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    fallback_providers: Optional[str] = None
    """Provider(s) used as a fallback if the call to `provider` fails."""

    edenai_api_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the constructor values or ``EDENAI_API_KEY``."""
        values["edenai_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
        )
        return values

    @staticmethod
    def get_user_agent() -> str:
        """Return the User-Agent string identifying this LangChain version."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @property
    def _llm_type(self) -> str:
        """Return the type of chat model."""
        return "edenai-chat"

    @property
    def _api_key(self) -> str:
        """The plain-text API key, or an empty string when none is set."""
        if self.edenai_api_key:
            return self.edenai_api_key.get_secret_value()
        return ""

    def _build_headers(self) -> Dict[str, str]:
        """Build the HTTP headers shared by every EdenAI chat request."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }

    def _build_payload(
        self, messages: List[BaseMessage], **kwargs: Any
    ) -> Dict[str, Any]:
        """Build the JSON body for the EdenAI chat endpoints.

        Model options come from the instance fields; ``kwargs`` are merged in
        last and therefore override them. Entries whose value is ``None`` are
        dropped. When ``model`` is set, it is sent under ``settings`` keyed by
        the provider name.
        """
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }
        # Omit unset options rather than sending explicit nulls.
        payload = {k: v for k, v in payload.items() if v is not None}
        if self.model is not None:
            payload["settings"] = {self.provider: self.model}
        return payload

    @staticmethod
    def _parse_stream_line(line: bytes) -> Optional[str]:
        """Extract the ``text`` token from one line of the streaming response.

        Returns ``None`` for blank lines (e.g. keep-alive separators), which
        would otherwise make ``json.loads`` raise.
        """
        decoded = line.decode()
        if not decoded.strip():
            return None
        return json.loads(decoded)["text"]

    def _build_chat_result(self, data: Dict[str, Any]) -> ChatResult:
        """Convert a non-streaming EdenAI response body into a ``ChatResult``.

        The entry for ``provider`` is used, unless ``fallback_providers`` is
        set and the response contains a non-empty entry under that key, in
        which case the fallback entry is used instead.

        Raises:
            KeyError: If the response has no entry for ``provider``.
            Exception: If the selected entry reports ``status == "fail"``;
                the message is taken from its ``error.message`` field.
        """
        provider_response = data[self.provider]
        if self.fallback_providers:
            fallback_response = data.get(self.fallback_providers)
            if fallback_response:
                provider_response = fallback_response
        if provider_response.get("status") == "fail":
            err_msg = provider_response.get("error", {}).get("message")
            raise Exception(err_msg)
        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=provider_response["generated_text"])
                )
            ],
            llm_output=data,
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to EdenAI's streaming chat endpoint.

        ``stop`` is accepted for interface compatibility but is not sent to
        EdenAI.
        """
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = self._build_headers()
        payload = self._build_payload(messages, **kwargs)

        request = Requests(headers=headers)
        # Use the response as a context manager so the streamed connection is
        # released even if the consumer stops iterating early.
        with request.post(url=url, data=payload, stream=True) as response:
            response.raise_for_status()
            for chunk_response in response.iter_lines():
                token = self._parse_stream_line(chunk_response)
                if token is None:
                    continue
                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=cg_chunk)
                yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously call out to EdenAI's streaming chat endpoint.

        ``stop`` is accepted for interface compatibility but is not sent to
        EdenAI.
        """
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = self._build_headers()
        payload = self._build_payload(messages, **kwargs)

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    token = self._parse_stream_line(chunk_response)
                    if token is None:
                        continue
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            token=token, chunk=cg_chunk
                        )
                    yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to EdenAI's chat endpoint.

        Delegates to ``_stream`` when ``streaming`` is enabled.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = self._build_headers()
        payload = self._build_payload(messages, **kwargs)

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload)
        response.raise_for_status()
        return self._build_chat_result(response.json())

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously call out to EdenAI's chat endpoint.

        Delegates to ``_astream`` when ``streaming`` is enabled.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = self._build_headers()
        payload = self._build_payload(messages, **kwargs)

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                data = await response.json()
        return self._build_chat_result(data)