Source code for langchain_community.chat_models.edenai

import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_community.utilities.requests import Requests


def _message_role(type: str) -> str:
    role_mapping = {"ai": "assistant", "human": "user", "chat": "user"}

    if type in role_mapping:
        return role_mapping[type]
    else:
        raise ValueError(f"Unknown type: {type}")
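For illustration, a quick check of the role mapping (a minimal sketch; it assumes this module is importable as shown):

    from langchain_community.chat_models.edenai import _message_role

    assert _message_role("ai") == "assistant"
    assert _message_role("human") == "user"
    assert _message_role("chat") == "user"
    # Any other type (e.g. "system") raises ValueError; system messages are
    # handled separately by _format_edenai_messages below.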


def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    system = None
    formatted_messages = []
    text = messages[-1].content
    for i, message in enumerate(messages[:-1]):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            system = message.content
        else:
            formatted_messages.append(
                {
                    "role": _message_role(message.type),
                    "message": message.content,
                }
            )
    return {
        "text": text,
        "previous_history": formatted_messages,
        "chatbot_global_action": system,
    }
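To make the payload shape concrete, here is a sketch of what the formatter returns for a short conversation (expected output shown as comments; the message classes come from langchain_core):

    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

    payload = _format_edenai_messages(
        [
            SystemMessage(content="You are terse."),
            HumanMessage(content="Hi"),
            AIMessage(content="Hello!"),
            HumanMessage(content="What is EdenAI?"),
        ]
    )
    # payload == {
    #     "text": "What is EdenAI?",  # last message, sent as the prompt
    #     "previous_history": [
    #         {"role": "user", "message": "Hi"},
    #         {"role": "assistant", "message": "Hello!"},
    #     ],
    #     "chatbot_global_action": "You are terse.",  # system message
    # }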


class ChatEdenAI(BaseChatModel):
    """`EdenAI` chat large language models.

    `EdenAI` is a versatile platform that lets you access a wide range of language
    models from different providers such as Google, OpenAI, Cohere, Mistral and more.

    To get started, make sure you have the environment variable ``EDENAI_API_KEY``
    set with your API key, or pass it as a named parameter to the constructor.

    Additionally, `EdenAI` gives you the flexibility to choose from a variety of
    models, including ones like "gpt-4".

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEdenAI
            from langchain_core.messages import HumanMessage

            # Initialize `ChatEdenAI` with the desired configuration
            chat = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75)

            # Create a list of messages to interact with the model
            messages = [HumanMessage(content="hello")]

            # Invoke the model with the provided messages
            chat.invoke(messages)

    `EdenAI` goes beyond simple model invocation. It offers advanced features:

    - **Multiple Providers**: Access a diverse range of language models offered by
      various providers, giving you the freedom to choose the model best suited to
      your use case.

    - **Fallback Mechanism**: Set up a fallback so that operation stays seamless:
      if the primary provider is unavailable, you can easily switch to an
      alternative provider.

    - **Usage Statistics**: Track usage statistics on a per-project and
      per-API-key basis. This feature lets you effectively monitor and manage
      resource consumption.

    - **Monitoring and Observability**: `EdenAI` provides comprehensive monitoring
      and observability tools on the platform.

    Example of setting up a fallback mechanism:
        .. code-block:: python

            # Initialize `ChatEdenAI` with a fallback provider
            chat_with_fallback = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75,
                fallback_providers="google")

    You can find more details here: https://docs.edenai.co/reference/text_chat_create
    """

    provider: str = "openai"
    """Chat provider to use (e.g. openai, google, etc.)"""

    model: Optional[str] = None
    """Model name for the above provider (e.g. 'gpt-4' for openai).
    Available models are shown on https://docs.edenai.co/ under 'available providers'.
    """

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    fallback_providers: Optional[str] = None
    """Providers listed here will be used as a fallback if the call to the main
    provider fails."""

    edenai_api_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        values["edenai_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
        )
        return values
    @staticmethod
    def get_user_agent() -> str:
        from langchain_community import __version__

        return f"langchain/{__version__}"
    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "edenai-chat"

    @property
    def _api_key(self) -> str:
        if self.edenai_api_key:
            return self.edenai_api_key.get_secret_value()
        return ""

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to EdenAI's chat endpoint."""
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload, stream=True)
        response.raise_for_status()

        for chunk_response in response.iter_lines():
            chunk = json.loads(chunk_response.decode())
            token = chunk["text"]
            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=cg_chunk)
            yield cg_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()

                async for chunk_response in response.content:
                    chunk = json.loads(chunk_response.decode())
                    token = chunk["text"]
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            token=token, chunk=cg_chunk
                        )
                    yield cg_chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to EdenAI's chat endpoint."""
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload)

        response.raise_for_status()
        data = response.json()
        provider_response = data[self.provider]

        if self.fallback_providers:
            fallback_response = data.get(self.fallback_providers)
            if fallback_response:
                provider_response = fallback_response

        if provider_response.get("status") == "fail":
            err_msg = provider_response.get("error", {}).get("message")
            raise Exception(err_msg)

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=provider_response["generated_text"])
                )
            ],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                data = await response.json()
                provider_response = data[self.provider]

                if self.fallback_providers:
                    fallback_response = data.get(self.fallback_providers)
                    if fallback_response:
                        provider_response = fallback_response

                if provider_response.get("status") == "fail":
                    err_msg = provider_response.get("error", {}).get("message")
                    raise Exception(err_msg)

                return ChatResult(
                    generations=[
                        ChatGeneration(
                            message=AIMessage(
                                content=provider_response["generated_text"]
                            )
                        )
                    ],
                    llm_output=data,
                )
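End to end, typical usage looks like the following (a minimal sketch; the provider and model values are illustrative, and ``EDENAI_API_KEY`` must be set in the environment):

    from langchain_community.chat_models import ChatEdenAI
    from langchain_core.messages import HumanMessage

    chat = ChatEdenAI(provider="openai", model="gpt-4", max_tokens=256)

    # Blocking call: POSTs to /text/chat and returns a single AIMessage.
    print(chat.invoke([HumanMessage(content="hello")]).content)

    # Streaming call: POSTs to /text/chat/stream and yields chunks as they arrive.
    for chunk in chat.stream([HumanMessage(content="hello")]):
        print(chunk.content, end="", flush=True)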