Source code for langchain_community.chat_models.zhipuai

"""ZhipuAI聊天模型封装。"""

from __future__ import annotations

import json
import logging
import time
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)

API_TOKEN_TTL_SECONDS = 3 * 60
ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions"


@contextmanager
def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator:
    from httpx_sse import EventSource

    with client.stream(method, url, **kwargs) as response:
        yield EventSource(response)

@asynccontextmanager
async def aconnect_sse(
    client: Any, method: str, url: str, **kwargs: Any
) -> AsyncIterator:
    from httpx_sse import EventSource

    async with client.stream(method, url, **kwargs) as response:
        yield EventSource(response)

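# Usage sketch (not part of the module): the two helpers above pair httpx's
# streaming request support with httpx-sse's EventSource parser. Assuming a
# server that speaks SSE (URL and body below are illustrative), consuming a
# stream looks roughly like this:
#
#     import httpx
#
#     with httpx.Client() as client:
#         with connect_sse(client, "POST", "https://example.com/sse") as source:
#             for sse in source.iter_sse():
#                 print(sse.data)
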
def _get_jwt_token(api_key: str) -> str:
    """Get a JWT token for the ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'.

    Args:
        api_key: The API key for the ZhipuAI API.

    Returns:
        The JWT token.
    """
    import jwt

    try:
        id, secret = api_key.split(".")
    except ValueError as err:
        raise ValueError(f"Invalid API key: {api_key}") from err

    payload = {
        "api_key": id,
        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }

    return jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )


def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage:
    role = dct.get("role")
    content = dct.get("content", "")
    if role == "system":
        return SystemMessage(content=content)
    if role == "user":
        return HumanMessage(content=content)
    if role == "assistant":
        additional_kwargs = {}
        tool_calls = dct.get("tool_calls", None)
        if tool_calls is not None:
            additional_kwargs["tool_calls"] = tool_calls
        return AIMessage(content=content, additional_kwargs=additional_kwargs)
    return ChatMessage(role=role, content=content)  # type: ignore[arg-type]


def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
    """Convert a LangChain message to a dictionary.

    Args:
        message: The LangChain message.

    Returns:
        The dictionary.
    """
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    else:
        raise TypeError(f"Got unknown type '{message.__class__.__name__}'.")
    return message_dict


def _convert_delta_to_message_chunk(
    dct: Dict[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = dct.get("role")
    content = dct.get("content", "")
    additional_kwargs = {}
    # Use "tool_calls" for consistency with _convert_dict_to_message above.
    tool_calls = dct.get("tool_calls", None)
    if tool_calls is not None:
        additional_kwargs["tool_calls"] = tool_calls

    if role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    if role == "assistant" or default_class == AIMessageChunk:
        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
    if role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
    return default_class(content=content)  # type: ignore[call-arg]


def _truncate_params(payload: Dict[str, Any]) -> None:
    """Truncate temperature and top_p parameters to the range [0.01, 0.99].

    ZhipuAI only supports temperature / top_p in the open interval (0, 1),
    so we clamp them to [0.01, 0.99].
    """
    temperature = payload.get("temperature")
    top_p = payload.get("top_p")
    if temperature is not None:
        payload["temperature"] = max(0.01, min(0.99, temperature))
    if top_p is not None:
        payload["top_p"] = max(0.01, min(0.99, top_p))

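# Illustration (not part of the module): _truncate_params mutates the payload
# in place, clamping out-of-range sampling parameters rather than rejecting
# them:
#
#     payload = {"temperature": 0.0, "top_p": 1.0}
#     _truncate_params(payload)
#     # payload is now {"temperature": 0.01, "top_p": 0.99}
#
# _get_jwt_token expects a key of the form "<id>.<secret>". Note the token is
# sent as-is in the Authorization header (no "Bearer " prefix) and "exp" is
# expressed in milliseconds, matching ZhipuAI's non-standard JWT scheme.
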
class ChatZhipuAI(BaseChatModel):
    """`ZhipuAI` large language chat models API.

    To use, you should have the ``PyJWT`` python package installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatZhipuAI

            zhipuai_chat = ChatZhipuAI(
                temperature=0.5,
                api_key="your-api-key",
                model="glm-4",
            )
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"zhipuai_api_key": "ZHIPUAI_API_KEY"}
    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "zhipuai"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        attributes: Dict[str, Any] = {}
        if self.zhipuai_api_base:
            attributes["zhipuai_api_base"] = self.zhipuai_api_base
        return attributes

    @property
    def _llm_type(self) -> str:
        """Return the type of chat model."""
        return "zhipuai-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the ZhipuAI API."""
        params = {
            "model": self.model_name,
            "stream": self.streaming,
            "temperature": self.temperature,
        }
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        return params

    zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key")
    """Automatically inferred from the env var `ZHIPUAI_API_KEY` if not provided."""
    zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base")
    """Base URL path for API requests; leave blank if not using a proxy or
    service emulator."""
    model_name: Optional[str] = Field(default="glm-4", alias="model")
    """Model name to use, see 'https://open.bigmodel.cn/dev/api#language'.
    Alternatively, you can use any fine-tuned model from the GLM series."""
    temperature: float = 0.95
    """What sampling temperature to use. The value ranges from 0.0 to 1.0 and
    cannot equal 0. The larger the value, the more random and creative the
    output; the smaller the value, the more stable and deterministic the
    output. It is recommended to adjust either top_p or temperature depending
    on the application scenario, but not both at the same time."""
    top_p: float = 0.7
    """Another sampling method, called nucleus sampling. The value ranges from
    0.0 to 1.0 and cannot equal 0 or 1. The model considers the results of the
    tokens within the top_p probability mass. For example, 0.1 means the model
    decoder only considers tokens from the top 10% of the probability candidate
    set. It is recommended to adjust either top_p or temperature depending on
    the application scenario, but not both at the same time."""
    streaming: bool = False
    """Whether to stream the results or not."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
        )
        values["zhipuai_api_base"] = get_from_dict_or_env(
            values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
        )
        return values

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = self._default_params
        if stop is not None:
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
        generations = []
        if not isinstance(response, dict):
            response = response.dict()
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            generation_info = dict(finish_reason=res.get("finish_reason"))
            generations.append(
                ChatGeneration(message=message, generation_info=generation_info)
            )
        token_usage = response.get("usage", {})
        llm_output = {
            "token_usage": token_usage,
            "model_name": self.model_name,
        }
        return ChatResult(generations=generations, llm_output=llm_output)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat response."""
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {
            **params,
            **kwargs,
            "messages": message_dicts,
            "stream": False,
        }
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            response = client.post(self.zhipuai_api_base, json=payload)
            response.raise_for_status()
            return self._create_chat_result(response.json())

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat response in chunks."""
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipuai_api_base.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        default_chunk_class = AIMessageChunk
        import httpx

        with httpx.Client(headers=headers, timeout=60) as client:
            with connect_sse(
                client, "POST", self.zhipuai_api_base, json=payload
            ) as event_source:
                for sse in event_source.iter_sse():
                    chunk = json.loads(sse.data)
                    if len(chunk["choices"]) == 0:
                        continue
                    choice = chunk["choices"][0]
                    chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason", None)
                    generation_info = (
                        {"finish_reason": finish_reason}
                        if finish_reason is not None
                        else None
                    )
                    chunk = ChatGenerationChunk(
                        message=chunk, generation_info=generation_info
                    )
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                    if finish_reason is not None:
                        break

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {
            **params,
            **kwargs,
            "messages": message_dicts,
            "stream": False,
        }
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        import httpx

        async with httpx.AsyncClient(headers=headers, timeout=60) as client:
            response = await client.post(self.zhipuai_api_base, json=payload)
            response.raise_for_status()
            return self._create_chat_result(response.json())

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        if self.zhipuai_api_key is None:
            raise ValueError("Did not find zhipuai_api_key.")
        if self.zhipuai_api_base is None:
            raise ValueError("Did not find zhipuai_api_base.")
        message_dicts, params = self._create_message_dicts(messages, stop)
        payload = {**params, **kwargs, "messages": message_dicts, "stream": True}
        _truncate_params(payload)
        headers = {
            "Authorization": _get_jwt_token(self.zhipuai_api_key),
            "Accept": "application/json",
        }

        default_chunk_class = AIMessageChunk
        import httpx

        async with httpx.AsyncClient(headers=headers, timeout=60) as client:
            async with aconnect_sse(
                client, "POST", self.zhipuai_api_base, json=payload
            ) as event_source:
                async for sse in event_source.aiter_sse():
                    chunk = json.loads(sse.data)
                    if len(chunk["choices"]) == 0:
                        continue
                    choice = chunk["choices"][0]
                    chunk = _convert_delta_to_message_chunk(
                        choice["delta"], default_chunk_class
                    )
                    finish_reason = choice.get("finish_reason", None)
                    generation_info = (
                        {"finish_reason": finish_reason}
                        if finish_reason is not None
                        else None
                    )
                    chunk = ChatGenerationChunk(
                        message=chunk, generation_info=generation_info
                    )
                    yield chunk
                    if run_manager:
                        await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                    if finish_reason is not None:
                        break
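
# End-to-end usage sketch (not part of the module; assumes ZHIPUAI_API_KEY is
# set in the environment or passed as api_key, and that the model and prompt
# are illustrative):
#
#     from langchain_core.messages import HumanMessage
#
#     chat = ChatZhipuAI(model="glm-4", temperature=0.5)
#     result = chat.invoke([HumanMessage(content="Hello!")])
#     print(result.content)
#
#     # Streaming variant, driven by _stream via the SSE helpers above:
#     for chunk in chat.stream([HumanMessage(content="Hello!")]):
#         print(chunk.content, end="", flush=True)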