"""封装了Prem的聊天API。"""
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
)
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import get_from_dict_or_env
if TYPE_CHECKING:
from premai.api.chat_completions.v1_chat_completions_create import (
ChatCompletionResponseStream,
)
from premai.models.chat_completion_response import ChatCompletionResponse
logger = logging.getLogger(__name__)
class ChatPremAPIError(Exception):
    """Error raised by the `PremAI` API."""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""在找到最早的停止标记处截断文本。"""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
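# A minimal illustration of _truncate_at_stop_tokens (a sketch with made-up
# strings, not part of the module's API surface):
#
#     _truncate_at_stop_tokens("Hello! How are you?", stop=["!"])
#     # -> "Hello"
#     _truncate_at_stop_tokens("Hello! How are you?", stop=None)
#     # -> "Hello! How are you?"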
def _response_to_result(
response: ChatCompletionResponse,
stop: Optional[List[str]],
) -> ChatResult:
"""将Prem API响应转换为LangChain结果"""
if not response.choices:
raise ChatPremAPIError("ChatResponse must have at least one candidate")
generations: List[ChatGeneration] = []
for choice in response.choices:
role = choice.message.role
if role is None:
raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")
# If content is None then it will be replaced by ""
content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop)
if content is None:
raise ChatPremAPIError(f"ChatResponse must have a content: {content}")
if role == "assistant":
generations.append(
ChatGeneration(text=content, message=AIMessage(content=content))
)
elif role == "user":
generations.append(
ChatGeneration(text=content, message=HumanMessage(content=content))
)
else:
generations.append(
ChatGeneration(
text=content, message=ChatMessage(role=role, content=content)
)
)
return ChatResult(generations=generations)
def _convert_delta_response_to_message_chunk(
response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
Optional[str],
]:
"""将增量响应转换为消息块"""
_delta = response.choices[0].delta # type: ignore
role = _delta.get("role", "") # type: ignore
content = _delta.get("content", "") # type: ignore
additional_kwargs: Dict = {}
if role is None or role == "":
raise ChatPremAPIError("Role can not be None. Please check the response")
finish_reasons: Optional[str] = response.choices[0].finish_reason
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content), finish_reasons
elif role == "assistant" or default_class == AIMessageChunk:
return (
AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
finish_reasons,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content), finish_reasons
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role), finish_reasons
else:
return default_class(content=content), finish_reasons # type: ignore[call-arg]
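# Sketch of the mapping above, using a made-up delta: a streamed payload like
# {"role": "assistant", "content": "Hi"} becomes
# (AIMessageChunk(content="Hi"), finish_reason), where finish_reason stays
# None until the provider marks the final chunk.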
def _messages_to_prompt_dict(
input_messages: List[BaseMessage],
) -> Tuple[Optional[str], List[Dict[str, str]]]:
"""将LangChain消息列表转换为Prem中的消息结构的简单字典。
"""
system_prompt: Optional[str] = None
examples_and_messages: List[Dict[str, str]] = []
for input_msg in input_messages:
if isinstance(input_msg, SystemMessage):
system_prompt = str(input_msg.content)
elif isinstance(input_msg, HumanMessage):
examples_and_messages.append(
{"role": "user", "content": str(input_msg.content)}
)
elif isinstance(input_msg, AIMessage):
examples_and_messages.append(
{"role": "assistant", "content": str(input_msg.content)}
)
else:
raise ChatPremAPIError("No such role explicitly exists")
return system_prompt, examples_and_messages
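# Illustrative sketch (hypothetical messages): a SystemMessage is returned
# separately as the system prompt, while human/AI turns become role dicts.
#
#     _messages_to_prompt_dict(
#         [SystemMessage(content="Be terse."), HumanMessage(content="Hi")]
#     )
#     # -> ("Be terse.", [{"role": "user", "content": "Hi"}])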
class ChatPremAI(BaseChatModel, BaseModel):
    """PremAI chat model.

    To use, you need an API key. You can find your existing API key
    or generate a new one here: https://app.premai.io/api_keys/
    """
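    # Example usage, as a minimal sketch: the project_id value is hypothetical
    # and the PREMAI_API_KEY environment variable is assumed to be set.
    #
    #     from langchain_community.chat_models import ChatPremAI
    #
    #     chat = ChatPremAI(project_id=8)
    #     chat.invoke("Write a haiku about the sea.")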
# TODO: Need to add the default parameters through prem-sdk here
project_id: int
"""实验或部署所在的项目ID。
您可以在此处找到所有您的项目:https://app.premai.io/projects/"""
premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Prem AI API密钥。在此处获取:https://app.premai.io/api_keys/"""
model: Optional[str] = Field(default=None, alias="model_name")
"""模型的名称。这是一个可选参数。
默认模型是从Prem的LaunchPad部署的模型:https://app.premai.io/projects/8/launchpad
如果模型名称不是默认模型,则会覆盖从launchpad部署的模型的调用。"""
session_id: Optional[str] = None
"""要使用的会话ID。它有助于跟踪聊天历史记录。"""
temperature: Optional[float] = None
"""模型温度。值应大于等于0且小于等于1.0。"""
top_p: Optional[float] = None
"""top_p根据累积概率调整每个预测标记的选择数量。值应在0.0和1.0之间。"""
max_tokens: Optional[int] = None
"""生成的令牌的最大数量"""
max_retries: int = 1
"""调用API的最大重试次数"""
system_prompt: Optional[str] = ""
"""表现为一个默认指令,帮助LLM以特定方式行动或生成。这是一个可选参数。默认情况下,系统提示将使用Prem的Launchpad模型系统提示。更改系统提示将覆盖默认系统提示。"""
streaming: Optional[bool] = False
"""是否要流式传输响应。"""
tools: Optional[Dict[str, Any]] = None
"""模型可能调用的工具列表。目前,仅支持函数作为工具。"""
frequency_penalty: Optional[float] = None
"""数字在-2.0和2.0之间。正值根据新令牌进行惩罚。"""
presence_penalty: Optional[float] = None
"""介于-2.0和2.0之间的数字。正值根据它们在文本中出现的频率对新标记进行惩罚。"""
logit_bias: Optional[dict] = None
"""将令牌映射到从-100到100的关联偏差值的JSON对象。"""
stop: Optional[Union[str, List[str]]] = None
"""最多生成4个序列,API将停止生成更多的令牌。"""
seed: Optional[int] = None
"""这个功能处于Beta阶段。如果指定,我们的系统将尽最大努力进行确定性采样。"""
client: Any
class Config:
"""此pydantic对象的配置。"""
extra = Extra.forbid
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""验证包是否已安装并且API令牌是否有效"""
try:
from premai import Prem
except ImportError as error:
raise ImportError(
"Could not import Prem Python package."
"Please install it with: `pip install premai`"
) from error
try:
premai_api_key = get_from_dict_or_env(
values, "premai_api_key", "PREMAI_API_KEY"
)
values["client"] = Prem(api_key=premai_api_key)
except Exception as error:
raise ValueError("Your API Key is incorrect. Please try again.") from error
return values
@property
def _llm_type(self) -> str:
return "premai"
@property
def _default_params(self) -> Dict[str, Any]:
        # FIXME: n and stop are not supported, so hardcoding to the current default value
return {
"model": self.model,
"system_prompt": self.system_prompt,
"top_p": self.top_p,
"temperature": self.temperature,
"logit_bias": self.logit_bias,
"max_tokens": self.max_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"seed": self.seed,
"stop": None,
}
def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
all_kwargs = {**self._default_params, **kwargs}
for key in list(self._default_params.keys()):
if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
all_kwargs.pop(key, None)
return all_kwargs
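    # Sketch of the filtering above: with only temperature=0.7 set on the
    # instance, self._get_all_kwargs(stop=["\n"]) would drop every None/""
    # default and return {"temperature": 0.7, "stop": ["\n"]}.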
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
kwargs["stop"] = stop
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt
all_kwargs = self._get_all_kwargs(**kwargs)
response = chat_with_retry(
self,
project_id=self.project_id,
messages=messages_to_pass,
stream=False,
run_manager=run_manager,
**all_kwargs,
)
return _response_to_result(response=response, stop=stop)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
kwargs["stop"] = stop
if "system_prompt" not in kwargs:
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt
all_kwargs = self._get_all_kwargs(**kwargs)
default_chunk_class = AIMessageChunk
for streamed_response in chat_with_retry(
self,
project_id=self.project_id,
messages=messages_to_pass,
stream=True,
run_manager=run_manager,
**all_kwargs,
):
try:
chunk, finish_reason = _convert_delta_response_to_message_chunk(
response=streamed_response, default_class=default_chunk_class
)
generation_info = (
dict(finish_reason=finish_reason)
if finish_reason is not None
else None
)
cg_chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info
)
if run_manager:
run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
yield cg_chunk
            except Exception:
                # Skip stream chunks that fail to parse into a message chunk.
                continue
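# Streaming usage sketch (assumes `chat` is a configured ChatPremAI
# instance); each yielded chunk carries a partial message:
#
#     for chunk in chat.stream("Tell me a joke"):
#         print(chunk.content, end="", flush=True)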
def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for PremAI API errors."""
import premai.models
errors = [
premai.models.api_response_validation_error.APIResponseValidationError,
premai.models.conflict_error.ConflictError,
premai.models.model_not_found_error.ModelNotFoundError,
premai.models.permission_denied_error.PermissionDeniedError,
premai.models.provider_api_connection_error.ProviderAPIConnectionError,
premai.models.provider_api_status_error.ProviderAPIStatusError,
premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
premai.models.provider_internal_server_error.ProviderInternalServerError,
premai.models.provider_not_found_error.ProviderNotFoundError,
premai.models.rate_limit_error.RateLimitError,
premai.models.unprocessable_entity_error.UnprocessableEntityError,
premai.models.validation_error.ValidationError,
]
decorator = create_base_retry_decorator(
error_types=errors, max_retries=max_retries, run_manager=run_manager
)
return decorator
def chat_with_retry(
llm: ChatPremAI,
project_id: int,
messages: List[dict],
stream: bool = False,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Any:
"""使用tenacity进行重试以完成调用"""
retry_decorator = create_prem_retry_decorator(
llm, max_retries=llm.max_retries, run_manager=run_manager
)
@retry_decorator
def _completion_with_retry(
project_id: int,
messages: List[dict],
stream: Optional[bool] = False,
**kwargs: Any,
) -> Any:
response = llm.client.chat.completions.create(
project_id=project_id,
messages=messages,
stream=stream,
**kwargs,
)
return response
return _completion_with_retry(
project_id=project_id,
messages=messages,
stream=stream,
**kwargs,
)
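# Usage sketch for chat_with_retry (hypothetical values; `llm` must be a
# configured ChatPremAI instance):
#
#     response = chat_with_retry(
#         llm,
#         project_id=llm.project_id,
#         messages=[{"role": "user", "content": "Hello"}],
#         stream=False,
#     )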