Source code for langchain_community.chat_models.premai

"""封装了Prem的聊天API。"""

from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from premai.api.chat_completions.v1_chat_completions_create import (
        ChatCompletionResponseStream,
    )
    from premai.models.chat_completion_response import ChatCompletionResponse

logger = logging.getLogger(__name__)


class ChatPremAPIError(Exception):
    """Error with the `PremAI` API."""


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncate text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: ChatCompletionResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Convert a Prem API response into a LangChain result."""
    if not response.choices:
        raise ChatPremAPIError("ChatResponse must have at least one candidate")
    generations: List[ChatGeneration] = []
    for choice in response.choices:
        role = choice.message.role
        if role is None:
            raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")

        # If content is None then it will be replaced by ""
        content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop)
        if content is None:
            raise ChatPremAPIError(f"ChatResponse must have a content: {content}")

        if role == "assistant":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif role == "user":
            generations.append(
                ChatGeneration(text=content, message=HumanMessage(content=content))
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content, message=ChatMessage(role=role, content=content)
                )
            )
    return ChatResult(generations=generations)


def _convert_delta_response_to_message_chunk(
    response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
    Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
    Optional[str],
]:
    """Convert a delta response into a message chunk."""
    _delta = response.choices[0].delta  # type: ignore
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}

    if role is None or role == "":
        raise ChatPremAPIError("Role can not be None. Please check the response")

    finish_reasons: Optional[str] = response.choices[0].finish_reason

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), finish_reasons
    elif role == "assistant" or default_class == AIMessageChunk:
        return (
            AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
            finish_reasons,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), finish_reasons
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), finish_reasons
    else:
        return default_class(content=content), finish_reasons  # type: ignore[call-arg]


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict[str, str]]]:
    """Convert a list of LangChain messages into the simple dict-based
    message structure used by Prem."""
    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, str]] = []

    for input_msg in input_messages:
        if isinstance(input_msg, SystemMessage):
            system_prompt = str(input_msg.content)
        elif isinstance(input_msg, HumanMessage):
            examples_and_messages.append(
                {"role": "user", "content": str(input_msg.content)}
            )
        elif isinstance(input_msg, AIMessage):
            examples_and_messages.append(
                {"role": "assistant", "content": str(input_msg.content)}
            )
        else:
            raise ChatPremAPIError("No such role explicitly exists")
    return system_prompt, examples_and_messages
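
# Illustrative sketch (not part of the original module) of how the helpers
# above behave; the message contents are placeholders.
#
#     system_prompt, messages = _messages_to_prompt_dict(
#         [
#             SystemMessage(content="You are a helpful assistant"),
#             HumanMessage(content="Hello"),
#         ]
#     )
#     # system_prompt == "You are a helpful assistant"
#     # messages == [{"role": "user", "content": "Hello"}]
#
#     _truncate_at_stop_tokens("Hello world", stop=["world"])  # -> "Hello "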
class ChatPremAI(BaseChatModel, BaseModel):
    """PremAI Chat models.

    To use, you will need an API key. You can find your existing API key
    or generate a new one here: https://app.premai.io/api_keys/
    """

    # TODO: Need to add the default parameters through prem-sdk here

    project_id: int
    """The ID of the project in which the experiments or deployments are carried out.
    You can find all your projects here: https://app.premai.io/projects/"""

    premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Prem AI API key. Get it here: https://app.premai.io/api_keys/"""

    model: Optional[str] = Field(default=None, alias="model_name")
    """Name of the model. This is an optional parameter.
    The default model is the one deployed from Prem's LaunchPad:
    https://app.premai.io/projects/8/launchpad
    If the model name is other than the default model, it will override the call
    to the model deployed from the launchpad."""

    session_id: Optional[str] = None
    """The ID of the session to use. It helps to track the chat history."""

    temperature: Optional[float] = None
    """Model temperature. Value should be >= 0 and <= 1.0."""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices for each predicted token based on
    cumulative probabilities. Value should be between 0.0 and 1.0."""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""

    max_retries: int = 1
    """Maximum number of retries when calling the API."""

    system_prompt: Optional[str] = ""
    """Acts as a default instruction that helps the LLM act or generate in a
    specific way. This is an optional parameter. By default the system prompt
    of Prem's Launchpad model is used. Setting a system prompt here overrides
    that default."""

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    tools: Optional[Dict[str, Any]] = None
    """A list of tools the model may call. Currently, only functions are
    supported as tools."""

    frequency_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based on
    their existing frequency in the text so far."""

    presence_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based on
    whether they appear in the text so far."""

    logit_bias: Optional[dict] = None
    """JSON object that maps tokens to an associated bias value from -100 to 100."""

    stop: Optional[Union[str, List[str]]] = None
    """Up to 4 sequences at which the API will stop generating further tokens."""

    seed: Optional[int] = None
    """This feature is in Beta. If specified, the system will make a best effort
    to sample deterministically."""

    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and the API token is valid."""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(api_key=premai_api_key)
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    @property
    def _llm_type(self) -> str:
        return "premai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        # FIXME: n and stop is not supported, so hardcoding to current default value
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "logit_bias": self.logit_bias,
            "max_tokens": self.max_tokens,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "seed": self.seed,
            "stop": None,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        kwargs["stop"] = stop
        if system_prompt is not None and system_prompt != "":
            kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)
        response = chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=False,
            run_manager=run_manager,
            **all_kwargs,
        )

        return _response_to_result(response=response, stop=stop)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
        kwargs["stop"] = stop

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
                kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)

        default_chunk_class = AIMessageChunk

        for streamed_response in chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=True,
            run_manager=run_manager,
            **all_kwargs,
        ):
            try:
                chunk, finish_reason = _convert_delta_response_to_message_chunk(
                    response=streamed_response, default_class=default_chunk_class
                )
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk
            except Exception as _:
                continue
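
# Minimal usage sketch (not part of the original module). The project_id and
# API key below are placeholders; both must come from https://app.premai.io/.
#
#     chat = ChatPremAI(project_id=8, premai_api_key="your-api-key")
#     result = chat.invoke(
#         [
#             SystemMessage(content="You are a helpful assistant"),
#             HumanMessage(content="Say hello"),
#         ]
#     )
#     print(result.content)
#
#     # Streaming goes through _stream() and yields chunks incrementally:
#     for chunk in chat.stream("Tell me a joke"):
#         print(chunk.content, end="")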
def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[Union[CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for PremAI API errors."""
    import premai.models

    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator
def chat_with_retry(
    llm: ChatPremAI,
    project_id: int,
    messages: List[dict],
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_prem_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        project_id: int,
        messages: List[dict],
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> Any:
        response = llm.client.chat.completions.create(
            project_id=project_id,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response

    return _completion_with_retry(
        project_id=project_id,
        messages=messages,
        stream=stream,
        **kwargs,
    )
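
# Rough sketch (not part of the original module) of how the retry path is
# exercised directly; `llm` is assumed to be an already-validated ChatPremAI
# instance and the message content is a placeholder. The number of retry
# attempts is read from llm.max_retries inside chat_with_retry.
#
#     response = chat_with_retry(
#         llm,
#         project_id=llm.project_id,
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=False,
#     )
#     result = _response_to_result(response=response, stop=None)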