from typing import Any, Dict, List, Literal, Union
from langchain_core.messages.base import (
BaseMessage,
BaseMessageChunk,
merge_content,
)
from langchain_core.messages.tool import (
InvalidToolCall,
ToolCall,
ToolCallChunk,
default_tool_chunk_parser,
default_tool_parser,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.utils._merge import merge_dicts, merge_lists
from langchain_core.utils.json import (
parse_partial_json,
)
[docs]class AIMessage(BaseMessage):
"""来自人工智能的消息。"""
example: bool = False
"""该消息是否作为示例对话的一部分传递给模型。"""
tool_calls: List[ToolCall] = []
"""如果提供了,工具调用与消息相关联。"""
invalid_tool_calls: List[InvalidToolCall] = []
"""如果提供了工具调用,则会出现与消息相关的解析错误。"""
type: Literal["ai"] = "ai"
[docs] @classmethod
def get_lc_namespace(cls) -> List[str]:
"""获取langchain对象的命名空间。"""
return ["langchain", "schema", "messages"]
@property
def lc_attributes(self) -> Dict:
"""需要序列化的属性,即使它们是从其他初始化参数派生而来。"""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator()
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
tool_calls = (
values.get("tool_calls")
or values.get("invalid_tool_calls")
or values.get("tool_call_chunks")
)
if raw_tool_calls and not tool_calls:
try:
if issubclass(cls, AIMessageChunk): # type: ignore
values["tool_call_chunks"] = default_tool_chunk_parser(
raw_tool_calls
)
else:
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
except Exception:
pass
return values
[docs] def pretty_repr(self, html: bool = False) -> str:
"""返回消息的漂亮表示形式。"""
base = super().pretty_repr(html=html)
lines = []
def _format_tool_args(tc: Union[ToolCall, InvalidToolCall]) -> List[str]:
lines = [
f" {tc.get('name', 'Tool')} ({tc.get('id')})",
f" Call ID: {tc.get('id')}",
]
if tc.get("error"):
lines.append(f" Error: {tc.get('error')}")
lines.append(" Args:")
args = tc.get("args")
if isinstance(args, str):
lines.append(f" {args}")
elif isinstance(args, dict):
for arg, value in args.items():
lines.append(f" {arg}: {value}")
return lines
if self.tool_calls:
lines.append("Tool Calls:")
for tc in self.tool_calls:
lines.extend(_format_tool_args(tc))
if self.invalid_tool_calls:
lines.append("Invalid Tool Calls:")
for itc in self.invalid_tool_calls:
lines.extend(_format_tool_args(itc))
return (base.strip() + "\n" + "\n".join(lines)).strip()
AIMessage.update_forward_refs()
[docs]class AIMessageChunk(AIMessage, BaseMessageChunk):
"""来自人工智能的消息块。"""
# Ignoring mypy re-assignment here since we're overriding the value
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
tool_call_chunks: List[ToolCallChunk] = []
"""如果提供了,工具调用与消息相关的块。"""
[docs] @classmethod
def get_lc_namespace(cls) -> List[str]:
"""获取langchain对象的命名空间。"""
return ["langchain", "schema", "messages"]
@property
def lc_attributes(self) -> Dict:
"""需要序列化的属性,即使它们是从其他初始化参数派生而来。"""
return {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}
@root_validator()
def init_tool_calls(cls, values: dict) -> dict:
if not values["tool_call_chunks"]:
values["tool_calls"] = []
values["invalid_tool_calls"] = []
return values
tool_calls = []
invalid_tool_calls = []
for chunk in values["tool_call_chunks"]:
try:
args_ = parse_partial_json(chunk["args"])
if isinstance(args_, dict):
tool_calls.append(
ToolCall(
name=chunk["name"] or "",
args=args_,
id=chunk["id"],
)
)
else:
raise ValueError("Malformed args.")
except Exception:
invalid_tool_calls.append(
InvalidToolCall(
name=chunk["name"],
args=chunk["args"],
id=chunk["id"],
error=None,
)
)
values["tool_calls"] = tool_calls
values["invalid_tool_calls"] = invalid_tool_calls
return values
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
content = merge_content(self.content, other.content)
additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs
)
response_metadata = merge_dicts(
self.response_metadata, other.response_metadata
)
# Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks:
raw_tool_calls = merge_lists(
self.tool_call_chunks,
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
else:
tool_call_chunks = []
return self.__class__(
example=self.example,
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
id=self.id,
)
return super().__add__(other)