"""封装了谷歌的PaLM Chat API。"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
)
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
if TYPE_CHECKING:
import google.generativeai as genai
logger = logging.getLogger(__name__)
class ChatGooglePalmError(Exception):
    """Error raised for problems with the `Google PaLM` API."""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""在找到最早的停止标记处截断文本。"""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
def _response_to_result(
    response: genai.types.ChatResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Convert a PaLM API response into a LangChain ChatResult.

    Args:
        response: The PaLM chat response. Each entry of
            ``response.candidates`` is read with ``.get("author")`` and
            ``.get("content", "")``.
        stop: Optional stop tokens; each candidate's content is truncated at
            the earliest one found.

    Returns:
        A ChatResult holding one ChatGeneration per candidate. The message
        type depends on the author: ``"ai"`` gives AIMessage, ``"human"``
        gives HumanMessage, and anything else gives a ChatMessage with that
        author as its role.

    Raises:
        ChatGooglePalmError: If the response has no candidates, or a
            candidate has no author or a ``None`` content.
    """
    if not response.candidates:
        raise ChatGooglePalmError("ChatResponse must have at least one candidate.")

    generations: List[ChatGeneration] = []
    for candidate in response.candidates:
        author = candidate.get("author")
        if author is None:
            raise ChatGooglePalmError(f"ChatResponse must have an author: {candidate}")

        # Validate before truncating: an explicit ``None`` content would
        # otherwise fail inside the truncation helper with an AttributeError
        # instead of the intended ChatGooglePalmError.
        raw_content = candidate.get("content", "")
        if raw_content is None:
            raise ChatGooglePalmError(f"ChatResponse must have a content: {candidate}")
        content = _truncate_at_stop_tokens(raw_content, stop)

        message: BaseMessage
        if author == "ai":
            message = AIMessage(content=content)
        elif author == "human":
            message = HumanMessage(content=content)
        else:
            message = ChatMessage(role=author, content=content)
        generations.append(ChatGeneration(text=content, message=message))

    return ChatResult(generations=generations)
def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
) -> genai.types.MessagePromptDict:
    """Convert a list of LangChain messages into a PaLM API MessagePrompt.

    The messages are split into three parts:

    * ``context``: content of a SystemMessage, which must be the first message.
    * ``examples``: (human, ai) pairs of messages flagged with ``example``,
      which must appear before any regular message.
    * ``messages``: all remaining AI, human and chat messages, in order.

    Args:
        input_messages: The LangChain messages to convert.

    Returns:
        A ``genai.types.MessagePromptDict`` built from the three parts.

    Raises:
        ChatGooglePalmError: If a system message is not first, examples come
            after regular messages, a human example is not immediately
            followed by an AI example (including when it is the last
            message), an AI example has no preceding human example, or a
            message type is unsupported.
    """
    import google.generativeai as genai

    context: str = ""
    examples: List[genai.types.MessageDict] = []
    messages: List[genai.types.MessageDict] = []

    # A single shared iterator lets the human-example branch consume the
    # following message without the O(n) cost of ``list.pop(0)``.
    pending = iter(enumerate(input_messages))
    for index, input_message in pending:
        if isinstance(input_message, SystemMessage):
            if index != 0:
                raise ChatGooglePalmError("System message must be first input message.")
            context = cast(str, input_message.content)
        elif isinstance(input_message, HumanMessage) and input_message.example:
            if messages:
                raise ChatGooglePalmError(
                    "Message examples must come before other messages."
                )
            # ``None`` stands in for "no further message", so a trailing human
            # example raises ChatGooglePalmError rather than a bare IndexError.
            _, next_input_message = next(pending, (None, None))
            if isinstance(next_input_message, AIMessage) and next_input_message.example:
                examples.extend(
                    [
                        genai.types.MessageDict(
                            author="human", content=input_message.content
                        ),
                        genai.types.MessageDict(
                            author="ai", content=next_input_message.content
                        ),
                    ]
                )
            else:
                raise ChatGooglePalmError(
                    "Human example message must be immediately followed by an "
                    " AI example response."
                )
        elif isinstance(input_message, AIMessage) and input_message.example:
            raise ChatGooglePalmError(
                "AI example message must be immediately preceded by a Human "
                "example message."
            )
        elif isinstance(input_message, AIMessage):
            messages.append(
                genai.types.MessageDict(author="ai", content=input_message.content)
            )
        elif isinstance(input_message, HumanMessage):
            messages.append(
                genai.types.MessageDict(author="human", content=input_message.content)
            )
        elif isinstance(input_message, ChatMessage):
            messages.append(
                genai.types.MessageDict(
                    author=input_message.role, content=input_message.content
                )
            )
        else:
            raise ChatGooglePalmError(
                "Messages without an explicit role not supported by PaLM API."
            )

    return genai.types.MessagePromptDict(
        context=context,
        examples=examples,
        messages=messages,
    )
def _create_retry_decorator() -> Callable[[Any], Any]:
    """Return a tenacity retry decorator preconfigured for PaLM exceptions.

    The decorator retries on ``ResourceExhausted``, ``ServiceUnavailable``
    and ``GoogleAPIError`` from ``google.api_core.exceptions``. It stops
    after 10 attempts, waits using
    ``wait_exponential(multiplier=2, min=1, max=60)``, logs at WARNING level
    before each sleep, and re-raises the last exception (``reraise=True``).
    """
    from google.api_core import exceptions as google_exceptions

    retryable_errors = (
        google_exceptions.ResourceExhausted,
        google_exceptions.ServiceUnavailable,
        google_exceptions.GoogleAPIError,
    )
    backoff = wait_exponential(multiplier=2, min=1, max=60)
    return retry(
        reraise=True,
        stop=stop_after_attempt(10),
        wait=backoff,
        retry=retry_if_exception_type(retryable_errors),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
def chat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Call ``llm.client.chat`` with tenacity-based retries.

    Args:
        llm: The chat model whose ``client`` performs the request.
        **kwargs: Keyword arguments forwarded unchanged to ``client.chat``.

    Returns:
        Whatever ``llm.client.chat`` returns.
    """

    def _chat_with_retry(**kwargs: Any) -> Any:
        return llm.client.chat(**kwargs)

    guarded_call = _create_retry_decorator()(_chat_with_retry)
    return guarded_call(**kwargs)
async def achat_with_retry(llm: ChatGooglePalm, **kwargs: Any) -> Any:
    """Await ``llm.client.chat_async`` with tenacity-based retries.

    Args:
        llm: The chat model whose ``client`` performs the request.
        **kwargs: Keyword arguments forwarded unchanged to
            ``client.chat_async``.

    Returns:
        Whatever ``llm.client.chat_async`` resolves to.
    """

    async def _achat_with_retry(**kwargs: Any) -> Any:
        # Async counterpart of ``client.chat`` on the configured client.
        return await llm.client.chat_async(**kwargs)

    guarded_call = _create_retry_decorator()(_achat_with_retry)
    return await guarded_call(**kwargs)
class ChatGooglePalm(BaseChatModel, BaseModel):
    """`Google PaLM` Chat models API.

    To use, you must have the google.generativeai Python package installed
    and either:

        1. Set the ``GOOGLE_API_KEY`` environment variable to your API key, or
        2. Pass your API key using the google_api_key kwarg to the
           ChatGooglePalm constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatGooglePalm
            chat = ChatGooglePalm()
    """

    client: Any  #: :meta private:
    model_name: str = "models/chat-bison-001"
    """Model name to use."""
    google_api_key: Optional[SecretStr] = None
    temperature: Optional[float] = None
    """Run inference with this temperature. Must be in the closed
    interval [0.0, 1.0]."""
    top_p: Optional[float] = None
    """Decode using nucleus sampling: consider the smallest set of tokens
    whose probability sum is at least top_p. Must be in the closed interval
    [0.0, 1.0]."""
    top_k: Optional[int] = None
    """Decode using top-k sampling: consider the set of top_k most probable
    tokens. Must be positive."""
    n: int = 1
    """Number of chat completions to generate for each prompt. Note that the
    API may not return the full n completions if duplicates are generated."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret field names to the environment variables backing them."""
        return {"google_api_key": "GOOGLE_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return True: this model is marked as serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "google_palm"]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the API key, package availability and sampling params.

        Reads the key from ``values["google_api_key"]`` or the
        ``GOOGLE_API_KEY`` environment variable, configures
        ``google.generativeai`` with it, and stores the module as
        ``values["client"]``.

        Raises:
            ChatGooglePalmError: If ``google.generativeai`` cannot be imported.
            ValueError: If temperature or top_p is outside [0.0, 1.0], or
                top_k is not positive.
        """
        google_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "google_api_key", "GOOGLE_API_KEY")
        )
        try:
            import google.generativeai as genai

            genai.configure(api_key=google_api_key.get_secret_value())
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible.
            raise ChatGooglePalmError(
                "Could not import google.generativeai python package. "
                "Please install it with `pip install google-generativeai`"
            ) from err

        values["client"] = genai

        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")

        if values["top_k"] is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")

        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result synchronously.

        ``stop`` is not sent to the API; it is applied afterwards by
        truncating each candidate's content. ``run_manager`` is accepted but
        not used here. Extra ``kwargs`` are forwarded to the chat call.
        """
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = chat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            **kwargs,
        )

        return _response_to_result(response, stop)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result asynchronously.

        Mirrors ``_generate``: ``stop`` is applied by truncating the returned
        content, ``run_manager`` is accepted but not used here, and extra
        ``kwargs`` are forwarded to the async chat call.
        """
        prompt = _messages_to_prompt_dict(messages)

        response: genai.types.ChatResponse = await achat_with_retry(
            self,
            model=self.model_name,
            prompt=prompt,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            candidate_count=self.n,
            # Previously dropped here, so per-call options were silently
            # ignored on the async path only.
            **kwargs,
        )

        return _response_to_result(response, stop)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "n": self.n,
        }

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this chat model type."""
        return "google-palm-chat"