Source code for langchain_community.chat_message_histories.rocksetdb

from datetime import datetime
from time import sleep
from typing import Any, Callable, List, Union
from uuid import uuid4

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)


class RocksetChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores its messages in Rockset.

    To use, make sure the `rockset` python package is installed.

    Example:
        .. code-block:: python

            from langchain_community.chat_message_histories import (
                RocksetChatMessageHistory
            )
            from rockset import RocksetClient

            history = RocksetChatMessageHistory(
                session_id="MySession",
                client=RocksetClient(),
                collection="langchain_demo",
                sync=True
            )

            history.add_user_message("hi!")
            history.add_ai_message("whats up?")

            print(history.messages)  # noqa: T201
    """

    # You should set these values based on your VI.
    # These values are configured for the typical
    # free VI. Read more about VIs here:
    #   https://rockset.com/docs/instances
    SLEEP_INTERVAL_MS: int = 5
    ADD_TIMEOUT_MS: int = 5000
    CREATE_TIMEOUT_MS: int = 20000

    def _wait_until(self, method: Callable, timeout: int, **method_params: Any) -> None:
        """Poll `method` until it returns a truthy value.

        Args:
            method: Callable to poll; invoked as `method(**method_params)`.
            timeout: Maximum time to wait, in milliseconds.
            **method_params: Keyword arguments forwarded to `method`.

        Raises:
            TimeoutError: If `method` is still falsy after `timeout` ms.
        """
        start = datetime.now()
        while not method(**method_params):
            curr = datetime.now()
            if (curr - start).total_seconds() * 1000 > timeout:
                raise TimeoutError(f"{method} timed out at {timeout} ms")
            sleep(RocksetChatMessageHistory.SLEEP_INTERVAL_MS / 1000)

    def _query(self, query: str, **query_params: Any) -> List[Any]:
        """Execute a SQL statement and return its results.

        Args:
            query: The SQL string.
            **query_params: Parameters bound to the query.
        """
        return self.client.sql(query, params=query_params).results

    def _create_collection(self) -> None:
        """Create the collection backing this message history."""
        self.client.Collections.create_s3_collection(
            name=self.collection, workspace=self.workspace
        )

    def _collection_exists(self) -> bool:
        """Return whether the collection for this message history exists."""
        try:
            # Pass the workspace explicitly so the lookup targets the same
            # workspace that the collection is created and queried in.
            self.client.Collections.get(
                collection=self.collection, workspace=self.workspace
            )
        except self.rockset.exceptions.NotFoundException:
            return False
        return True

    def _collection_is_ready(self) -> bool:
        """Return whether the collection reports the status `"READY"`."""
        # Same workspace-scoped lookup as `_collection_exists`.
        collection_info = self.client.Collections.get(
            collection=self.collection, workspace=self.workspace
        )
        return collection_info.data.status == "READY"

    def _document_exists(self) -> bool:
        """Return whether a document with `_id == session_id` exists."""
        rows = self._query(
            f"""
                SELECT 1
                FROM {self.location}
                WHERE _id=:session_id
                LIMIT 1
            """,
            session_id=self.session_id,
        )
        return len(rows) != 0

    def _wait_until_collection_created(self) -> None:
        """Block until the collection is ready to be queried.

        Raises:
            TimeoutError: If it is not ready within `CREATE_TIMEOUT_MS`.
        """
        self._wait_until(
            lambda: self._collection_is_ready(),
            RocksetChatMessageHistory.CREATE_TIMEOUT_MS,
        )

    def _wait_until_message_added(self, message_id: str) -> None:
        """Block until the message with `message_id` is visible in the history.

        Args:
            message_id: Value stored under `additional_kwargs.id` of the message.

        Raises:
            TimeoutError: If the message is not visible within `ADD_TIMEOUT_MS`.
        """
        # The messages key is double-quoted here exactly as in the `messages`
        # property, so both queries resolve the same column.
        self._wait_until(
            lambda message_id: len(
                self._query(
                    f"""
                        SELECT *
                        FROM UNNEST((
                            SELECT "{self.messages_key}"
                            FROM {self.location}
                            WHERE _id = :session_id
                        )) AS message
                        WHERE message.data.additional_kwargs.id = :message_id
                        LIMIT 1
                    """,
                    session_id=self.session_id,
                    message_id=message_id,
                ),
            )
            != 0,
            RocksetChatMessageHistory.ADD_TIMEOUT_MS,
            message_id=message_id,
        )

    def _create_empty_doc(self) -> None:
        """Create or replace the document for this history with no messages."""
        self.client.Documents.add_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[{"_id": self.session_id, self.messages_key: []}],
        )

    def __init__(
        self,
        session_id: str,
        client: Any,
        collection: str,
        workspace: str = "commons",
        messages_key: str = "messages",
        sync: bool = False,
        message_uuid_method: Callable[[], Union[str, int]] = lambda: str(uuid4()),
    ) -> None:
        """Construct a new RocksetChatMessageHistory.

        Args:
            session_id: The ID of the chat session; used as the document `_id`.
            client: The RocksetClient object used for all requests.
            collection: Name of the collection that stores chat messages. If
                it does not exist, it is created.
            workspace: The workspace containing `collection`. Defaults to
                `"commons"`.
            messages_key: The document field holding the message list.
                Defaults to `"messages"`.
            sync: Whether to wait for writes to become visible before
                returning. Defaults to `False`. NOTE: setting this to `True`
                slows down writes.
            message_uuid_method: Callable that produces a message ID. It is
                invoked by `add_message` only when `sync` is `True` and the
                message has no `id` in its `additional_kwargs`. Defaults to
                a `uuid.uuid4`-based string.

        Raises:
            ImportError: If the `rockset` package is not installed.
            ValueError: If `client` is not a `rockset.RocksetClient`.
            TimeoutError: If a newly created collection does not become ready
                within `CREATE_TIMEOUT_MS`.
        """
        try:
            import rockset
        except ImportError as e:
            raise ImportError(
                "Could not import rockset client python package. "
                "Please install it with `pip install rockset`."
            ) from e

        if not isinstance(client, rockset.RocksetClient):
            raise ValueError(
                f"client should be an instance of rockset.RocksetClient, "
                f"got {type(client)}"
            )

        self.session_id = session_id
        self.client = client
        self.collection = collection
        self.workspace = workspace
        self.location = f'"{self.workspace}"."{self.collection}"'
        self.rockset = rockset
        self.messages_key = messages_key
        self.message_uuid_method = message_uuid_method
        self.sync = sync

        # Best effort: tolerate clients that have no `set_application`.
        try:
            self.client.set_application("langchain")
        except AttributeError:
            # ignore
            pass

        if not self._collection_exists():
            self._create_collection()
            self._wait_until_collection_created()
            self._create_empty_doc()
        elif not self._document_exists():
            self._create_empty_doc()

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Messages in this chat history."""
        return messages_from_dict(
            self._query(
                f"""
                    SELECT *
                    FROM UNNEST ((
                        SELECT "{self.messages_key}"
                        FROM {self.location}
                        WHERE _id = :session_id
                    ))
                """,
                session_id=self.session_id,
            )
        )

    def add_message(self, message: BaseMessage) -> None:
        """Append a message to the history.

        When `sync` is enabled, the message is given an `id` in its
        `additional_kwargs` (if it has none) and this call blocks until the
        message is visible to queries.

        Args:
            message: The BaseMessage object to store.

        Raises:
            TimeoutError: If `sync` is enabled and the message does not become
                visible within `ADD_TIMEOUT_MS`.
        """
        # The id is what `_wait_until_message_added` looks up afterwards.
        if self.sync and "id" not in message.additional_kwargs:
            message.additional_kwargs["id"] = self.message_uuid_method()
        self.client.Documents.patch_documents(
            collection=self.collection,
            workspace=self.workspace,
            data=[
                self.rockset.model.patch_document.PatchDocument(
                    id=self.session_id,
                    patch=[
                        self.rockset.model.patch_operation.PatchOperation(
                            op="ADD",
                            path=f"/{self.messages_key}/-",
                            value=message_to_dict(message),
                        )
                    ],
                )
            ],
        )
        if self.sync:
            self._wait_until_message_added(message.additional_kwargs["id"])

    def clear(self) -> None:
        """Remove all messages from the chat history.

        Raises:
            TimeoutError: If `sync` is enabled and the history is not observed
                empty within `ADD_TIMEOUT_MS`.
        """
        self._create_empty_doc()
        if self.sync:
            self._wait_until(
                lambda: not self.messages,
                RocksetChatMessageHistory.ADD_TIMEOUT_MS,
            )