from datetime import datetime
from time import monotonic, sleep
from typing import Any, Callable, List, Union
from uuid import uuid4

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)
class RocksetChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Rockset collection.

    Each session is kept as a single document whose ``_id`` is the session id
    and whose ``messages_key`` field holds the list of serialized messages.

    To use, make sure the ``rockset`` python package is installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_message_histories import (
                RocksetChatMessageHistory
            )
            from rockset import RocksetClient

            history = RocksetChatMessageHistory(
                session_id="MySession",
                client=RocksetClient(),
                collection="langchain_demo",
                sync=True
            )

            history.add_user_message("hi!")
            history.add_ai_message("whats up?")

            print(history.messages)  # noqa: T201
    """

    # You should set these values based on your VI.
    # These values are configured for the typical
    # free VI. Read more about VIs here:
    # 	https://rockset.com/docs/instances
    # They are looked up through ``self``, so they may be overridden on the
    # class, on a subclass, or on a single instance.
    SLEEP_INTERVAL_MS: int = 5  # pause between two polls
    ADD_TIMEOUT_MS: int = 5000  # max wait for a message write / clear
    CREATE_TIMEOUT_MS: int = 20000  # max wait for a new collection

    def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
        """Poll ``method`` until it returns a truthy value.

        Args:
            method: Callable to poll; invoked as ``method(**method_params)``.
            timeout: Maximum time to wait, in milliseconds.
            **method_params: Keyword arguments forwarded to ``method``.

        Raises:
            TimeoutError: If ``method`` is still falsy once ``timeout``
                milliseconds have elapsed.
        """
        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = monotonic() + timeout / 1000
        while not method(**method_params):
            if monotonic() > deadline:
                raise TimeoutError(f"{method} timed out at {timeout} ms")
            sleep(self.SLEEP_INTERVAL_MS / 1000)

    def _query(self, query: str, **query_params: Any) -> List[Any]:
        """Execute a SQL statement and return its results.

        Args:
            query: The SQL string.
            **query_params: Named parameters bound to the query.

        Returns:
            The ``results`` attribute of the client's SQL response.
        """
        return self.client.sql(query, params=query_params).results

    def _create_collection(self) -> None:
        """Create the collection backing this message history."""
        self.client.Collections.create_s3_collection(
            name=self.collection, workspace=self.workspace
        )

    def _collection_exists(self) -> bool:
        """Return whether the collection exists in the configured workspace."""
        try:
            # The workspace must be passed explicitly; otherwise the lookup
            # would not target the workspace the collection is created in.
            self.client.Collections.get(
                collection=self.collection, workspace=self.workspace
            )
        except self.rockset.exceptions.NotFoundException:
            return False
        return True

    def _collection_is_ready(self) -> bool:
        """Return whether the collection's status is ``"READY"``."""
        response = self.client.Collections.get(
            collection=self.collection, workspace=self.workspace
        )
        return response.data.status == "READY"

    def _document_exists(self) -> bool:
        """Return whether a document for this session id already exists."""
        rows = self._query(
            f"""
                SELECT 1
                FROM {self.location}
                WHERE _id=:session_id
                LIMIT 1
            """,
            session_id=self.session_id,
        )
        return bool(rows)

    def _wait_until_collection_created(self) -> None:
        """Block until the collection is ready to be queried.

        Raises:
            TimeoutError: If the collection is not ready within
                ``CREATE_TIMEOUT_MS`` milliseconds.
        """
        self._wait_until(self._collection_is_ready, self.CREATE_TIMEOUT_MS)

    def _message_exists(self, message_id: str) -> bool:
        """Return whether a message with ``message_id`` is visible to queries.

        The id is matched against ``data.additional_kwargs.id`` of every
        message stored in this session's document.
        """
        rows = self._query(
            f"""
                SELECT *
                FROM UNNEST((
                    SELECT "{self.messages_key}"
                    FROM {self.location}
                    WHERE _id = :session_id
                )) AS message
                WHERE message.data.additional_kwargs.id = :message_id
                LIMIT 1
            """,
            session_id=self.session_id,
            message_id=message_id,
        )
        return bool(rows)

    def _wait_until_message_added(self, message_id: str) -> None:
        """Block until the message with ``message_id`` shows up in queries.

        Raises:
            TimeoutError: If the message is not visible within
                ``ADD_TIMEOUT_MS`` milliseconds.
        """
        self._wait_until(
            self._message_exists,
            self.ADD_TIMEOUT_MS,
            message_id=message_id,
        )

    def _create_empty_doc(self) -> None:
        """Create or replace this session's document with an empty message list."""
        self.client.Documents.add_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[{"_id": self.session_id, self.messages_key: []}],
        )

    def __init__(
        self,
        session_id: str,
        client: Any,
        collection: str,
        workspace: str = "commons",
        messages_key: str = "messages",
        sync: bool = False,
        message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
    ) -> None:
        """Construct a new RocksetChatMessageHistory.

        If the collection does not exist it is created, awaited until ready,
        and seeded with an empty document for the session. If the collection
        exists but has no document for the session, an empty one is added.

        Args:
            session_id: The ID of the chat session.
            client: The ``rockset.RocksetClient`` used for all requests.
            collection: Name of the collection storing the chat messages.
            workspace: The workspace containing ``collection``. Defaults to
                ``"commons"``.
            messages_key: The document field holding the message list.
                Defaults to ``"messages"``.
            sync: Whether to wait until writes are visible to queries.
                Defaults to ``False``. NOTE: setting this to ``True`` slows
                down every write.
            message_uuid_method: Callable producing a message id. It is only
                invoked when ``sync`` is ``True`` and the message has no
                ``"id"`` in ``additional_kwargs``. Defaults to a function
                returning ``str(uuid4())``.

        Raises:
            ImportError: If the ``rockset`` package is not installed.
            ValueError: If ``client`` is not a ``rockset.RocksetClient``.
            TimeoutError: If a newly created collection does not become ready
                within ``CREATE_TIMEOUT_MS`` milliseconds.
        """
        try:
            import rockset
        except ImportError as err:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from err

        if not isinstance(client, rockset.RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )

        self.session_id = session_id
        self.client = client
        self.collection = collection
        self.workspace = workspace
        # Fully qualified, quoted name used in the FROM clause of every query.
        self.location = f'"{self.workspace}"."{self.collection}"'
        # Keep the module around for its exception and patch model classes.
        self.rockset = rockset
        self.messages_key = messages_key
        self.message_uuid_method = message_uuid_method
        self.sync = sync

        try:
            self.client.set_application("langchain")
        except AttributeError:
            # Best effort: clients without `set_application` are left as-is.
            pass

        if not self._collection_exists():
            self._create_collection()
            self._wait_until_collection_created()
            self._create_empty_doc()
        elif not self._document_exists():
            self._create_empty_doc()

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Messages of this chat session, queried from Rockset on every access."""
        rows = self._query(
            f"""
                SELECT *
                FROM UNNEST ((
                    SELECT "{self.messages_key}"
                    FROM {self.location}
                    WHERE _id = :session_id
                ))
            """,
            session_id=self.session_id,
        )
        return messages_from_dict(rows)

    def add_message(self, message: BaseMessage) -> None:
        """Append a message to the stored history.

        When ``sync`` is enabled and the message carries no ``"id"`` in its
        ``additional_kwargs``, one is generated and written onto the passed-in
        message object; the call then blocks until that id is visible to
        queries.

        Args:
            message: The BaseMessage object to store.

        Raises:
            TimeoutError: If ``sync`` is enabled and the message is not
                visible within ``ADD_TIMEOUT_MS`` milliseconds.
        """
        if self.sync and "id" not in message.additional_kwargs:
            message.additional_kwargs["id"] = self.message_uuid_method()

        # "/<messages_key>/-" appends to the end of the message array.
        append_op = self.rockset.model.patch_operation.PatchOperation(
            op="ADD",
            path=f"/{self.messages_key}/-",
            value=message_to_dict(message),
        )
        self.client.Documents.patch_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[
                self.rockset.model.patch_document.PatchDocument(
                    id=self.session_id,
                    patch=[append_op],
                )
            ],
        )

        if self.sync:
            self._wait_until_message_added(message.additional_kwargs["id"])

    def clear(self) -> None:
        """Remove all messages from the chat history.

        Raises:
            TimeoutError: If ``sync`` is enabled and the history is not
                observed empty within ``ADD_TIMEOUT_MS`` milliseconds.
        """
        self._create_empty_doc()
        if self.sync:
            self._wait_until(
                lambda: not self.messages,
                self.ADD_TIMEOUT_MS,
            )