"""Wrapper around Prem's Chat API."""from__future__importannotationsimportloggingimportwarningsfromtypingimport(TYPE_CHECKING,Any,Callable,Dict,Iterator,List,Optional,Sequence,Tuple,Type,Union,)fromlangchain_core.callbacksimport(CallbackManagerForLLMRun,)fromlangchain_core.language_modelsimportLanguageModelInputfromlangchain_core.language_models.chat_modelsimportBaseChatModelfromlangchain_core.language_models.llmsimportcreate_base_retry_decoratorfromlangchain_core.messagesimport(AIMessage,AIMessageChunk,BaseMessage,BaseMessageChunk,ChatMessage,ChatMessageChunk,HumanMessage,HumanMessageChunk,SystemMessage,SystemMessageChunk,ToolMessage,)fromlangchain_core.outputsimportChatGeneration,ChatGenerationChunk,ChatResultfromlangchain_core.runnablesimportRunnablefromlangchain_core.toolsimportBaseToolfromlangchain_core.utilsimportget_from_dict_or_env,pre_initfromlangchain_core.utils.function_callingimportconvert_to_openai_toolfrompydanticimport(BaseModel,ConfigDict,Field,SecretStr,)ifTYPE_CHECKING:frompremai.api.chat_completions.v1_chat_completions_createimport(ChatCompletionResponseStream,)frompremai.models.chat_completion_responseimportChatCompletionResponselogger=logging.getLogger(__name__)TOOL_PROMPT_HEADER="""Given the set of tools you used and the response, provide the final answer\n"""INTERMEDIATE_TOOL_RESULT_TEMPLATE="""{json}"""SINGLE_TOOL_PROMPT_TEMPLATE="""tool id: {tool_id}tool_response: {tool_response}"""


class ChatPremAPIError(Exception):
    """Error with the `PremAI` API."""


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text

    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: ChatCompletionResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a Prem API response into a LangChain result."""
    if not response.choices:
        raise ChatPremAPIError("ChatResponse must have at least one candidate")
    generations: List[ChatGeneration] = []
    for choice in response.choices:
        role = choice.message.role
        if role is None:
            raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")

        # If content is None then it will be replaced by ""
        content = _truncate_at_stop_tokens(
            text=choice.message.content or "", stop=stop
        )
        if content is None:
            raise ChatPremAPIError(f"ChatResponse must have a content: {content}")

        if role == "assistant":
            tool_calls = choice.message["tool_calls"]
            if tool_calls is None:
                tools = []
            else:
                tools = [
                    {
                        "id": tool_call["id"],
                        "name": tool_call["function"]["name"],
                        "args": tool_call["function"]["arguments"],
                    }
                    for tool_call in tool_calls
                ]
            generations.append(
                ChatGeneration(
                    text=content,
                    message=AIMessage(content=content, tool_calls=tools),
                )
            )
        elif role == "user":
            generations.append(
                ChatGeneration(text=content, message=HumanMessage(content=content))
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content, message=ChatMessage(role=role, content=content)
                )
            )
    if response.document_chunks is not None:
        return ChatResult(
            generations=generations,
            llm_output={
                "document_chunks": [
                    chunk.to_dict() for chunk in response.document_chunks
                ]
            },
        )
    else:
        return ChatResult(
            generations=generations, llm_output={"document_chunks": None}
        )


def _convert_delta_response_to_message_chunk(
    response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
    Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
    Optional[str],
]:
    """Converts a delta response to a message chunk."""
    _delta = response.choices[0].delta  # type: ignore
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}
    finish_reasons: Optional[str] = response.choices[0].finish_reason

    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), finish_reasons
    elif role == "assistant" or default_class == AIMessageChunk:
        return (
            AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
            finish_reasons,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), finish_reasons
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), finish_reasons
    else:
        return default_class(content=content), finish_reasons  # type: ignore[call-arg]


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
    template_id: Optional[str] = None,
) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """Converts a list of LangChain messages into the simple dict
    structure that Prem expects."""
    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, Any]] = []

    for input_msg in input_messages:
        if isinstance(input_msg, SystemMessage):
            system_prompt = str(input_msg.content)
        elif isinstance(input_msg, HumanMessage):
            if template_id is None:
                examples_and_messages.append(
                    {"role": "user", "content": str(input_msg.content)}
                )
            else:
                params: Dict[str, str] = {}
                assert (input_msg.id is not None) and (input_msg.id != ""), ValueError(
                    "When using a prompt template there should be an id associated "
                    "with each HumanMessage"
                )
                params[str(input_msg.id)] = str(input_msg.content)
                examples_and_messages.append(
                    {"role": "user", "template_id": template_id, "params": params}
                )
        elif isinstance(input_msg, AIMessage):
            if input_msg.tool_calls is None or len(input_msg.tool_calls) == 0:
                examples_and_messages.append(
                    {"role": "assistant", "content": str(input_msg.content)}
                )
            else:
                ai_msg_to_json = {
                    "id": input_msg.id,
                    "content": input_msg.content,
                    "response_metadata": input_msg.response_metadata,
                    "tool_calls": input_msg.tool_calls,
                }
                examples_and_messages.append(
                    {
                        "role": "assistant",
                        "content": INTERMEDIATE_TOOL_RESULT_TEMPLATE.format(
                            json=ai_msg_to_json,
                        ),
                    }
                )
        elif isinstance(input_msg, ToolMessage):
            pass
        else:
            raise ChatPremAPIError("No such role explicitly exists")

    # Do a separate pass to collect tool-call results.
    tool_prompt = ""
    for input_msg in input_messages:
        if isinstance(input_msg, ToolMessage):
            tool_id = input_msg.tool_call_id
            tool_result = input_msg.content
            tool_prompt += SINGLE_TOOL_PROMPT_TEMPLATE.format(
                tool_id=tool_id, tool_response=tool_result
            )
    if tool_prompt != "":
        prompt = TOOL_PROMPT_HEADER
        prompt += tool_prompt
        examples_and_messages.append({"role": "user", "content": prompt})

    return system_prompt, examples_and_messages


class ChatPremAI(BaseChatModel, BaseModel):
    """PremAI Chat models.

    To use, you will need an API key. You can find your existing API key
    or generate a new one here: https://app.premai.io/api_keys/
    """

    # TODO: Need to add the default parameters through prem-sdk here

    project_id: int
    """The project ID in which the experiments or deployments are carried out.
    You can find all your projects here: https://app.premai.io/projects/"""

    premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""

    model: Optional[str] = Field(default=None, alias="model_name")
    """Name of the model. This is an optional parameter. The default model is
    the one deployed from Prem's LaunchPad:
    https://app.premai.io/projects/8/launchpad
    If a model name other than the default is given, it overrides the model
    deployed from the LaunchPad."""

    session_id: Optional[str] = None
    """The ID of the session to use. It helps to track the chat history."""

    temperature: Optional[float] = Field(default=None)
    """Model temperature. Value should be >= 0 and <= 1.0"""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices considered for each predicted
    token, based on cumulative probability. Value should be between
    0.0 and 1.0."""

    max_tokens: Optional[int] = Field(default=None)
    """The maximum number of tokens to generate"""

    max_retries: int = Field(default=1)
    """Max number of retries to call the API"""

    system_prompt: Optional[str] = ""
    """Acts like a default instruction that helps the LLM act or generate in
    a specific way. This is an optional parameter. By default the system
    prompt is the one set for the model on Prem's LaunchPad; setting this
    overrides it."""

    repositories: Optional[dict] = None
    """Add valid repository ids. This overrides the repositories already
    connected to the project (if any) and uses RAG with the given repos."""

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    client: Any = None

    model_config = ConfigDict(
        populate_by_name=True,
        arbitrary_types_allowed=True,
        extra="forbid",
    )
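
    # Construction sketch (illustrative, not from the original module; the
    # values below are placeholders): only project_id is required. The API
    # key may be passed via the `api_key` alias or read from the
    # PREMAI_API_KEY environment variable; unset fields fall back to the
    # model deployed from the project's LaunchPad.
    #
    #     chat = ChatPremAI(project_id=8, temperature=0.2, max_tokens=256)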

    @pre_init
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and the API token is valid."""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import the Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key: Union[str, SecretStr] = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(
                api_key=premai_api_key
                if isinstance(premai_api_key, str)
                else premai_api_key.get_secret_value()
            )
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    @property
    def _llm_type(self) -> str:
        return "premai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "repositories": self.repositories,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        kwargs_to_ignore = [
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "logit_bias",
            "stop",
            "seed",
        ]
        keys_to_remove = []
        for key in kwargs:
            if key in kwargs_to_ignore:
                warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
                keys_to_remove.append(key)

        for key in keys_to_remove:
            kwargs.pop(key)

        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if "template_id" in kwargs:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(
                messages, template_id=kwargs["template_id"]
            )
        else:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        if system_prompt is not None and system_prompt != "":
            kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)
        response = chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=False,
            run_manager=run_manager,
            **all_kwargs,
        )

        return _response_to_result(response=response, stop=stop)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        if "template_id" in kwargs:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(
                messages, template_id=kwargs["template_id"]
            )  # type: ignore
        else:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        if stop is not None:
            logger.warning("stop is not supported in langchain streaming")

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
                kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)
        default_chunk_class = AIMessageChunk

        for streamed_response in chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=True,
            run_manager=run_manager,
            **all_kwargs,
        ):
            try:
                chunk, finish_reason = _convert_delta_response_to_message_chunk(
                    response=streamed_response, default_class=default_chunk_class
                )
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk
            except Exception:
                # Skip malformed stream chunks rather than aborting the stream.
                continue
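
# Prompt-template sketch (illustrative; the template id is a placeholder):
# when a `template_id` kwarg is passed through `invoke`, each HumanMessage
# must carry an `id`, which becomes the template parameter name while the
# message content becomes its value (see _messages_to_prompt_dict above).
#
#     chat.invoke(
#         [HumanMessage(content="shirt", id="product")],
#         template_id="<template-uuid>",
#     )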


def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for PremAI API errors."""
    import premai.models

    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def chat_with_retry(
    llm: ChatPremAI,
    project_id: int,
    messages: List[dict],
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_prem_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        project_id: int,
        messages: List[dict],
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> Any:
        response = llm.client.chat.completions.create(
            project_id=project_id,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response

    return _completion_with_retry(
        project_id=project_id,
        messages=messages,
        stream=stream,
        **kwargs,
    )
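

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumes
    # PREMAI_API_KEY is set in the environment and that the project id
    # (8 here is a placeholder) has a model deployed from its LaunchPad.
    chat = ChatPremAI(project_id=8)

    # Single-shot generation; `stop` sequences are applied client-side by
    # _truncate_at_stop_tokens after the response arrives.
    print(chat.invoke("What is the capital of France?", stop=["\n"]))

    # Token-by-token streaming via _stream; note that `stop` is ignored
    # (and logged as unsupported) on the streaming path.
    for chunk in chat.stream("Name three prime numbers."):
        print(chunk.content, end="", flush=True)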