# Source code for langchain_community.chat_message_histories.cassandra
"""Cassandra-based chat message history, built on cassIO."""
from __future__ import annotations
import json
import uuid
from typing import TYPE_CHECKING, List, Optional
if TYPE_CHECKING:
from cassandra.cluster import Session
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
# Table name used when the caller does not pass `table_name` to the history class.
DEFAULT_TABLE_NAME = "message_store"
# Default time-to-live (seconds) for stored entries; None means entries do not expire.
DEFAULT_TTL_SECONDS = None
class CassandraChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that is stored in Cassandra.

    All messages of one chat session are written to, and read from, the
    table partition identified by ``session_id``.

    Args:
        session_id: arbitrary key that is used to store the messages
            of a single chat session.
        session: Cassandra driver session.
            If not provided, it is resolved from cassio.
        keyspace: Cassandra keyspace.
            If not provided, it is resolved from cassio.
        table_name: name of the table to use.
        ttl_seconds: time-to-live (seconds) for automatic expiration
            of stored entries. None (the default) means no expiration.
    """

    def __init__(
        self,
        session_id: str,
        session: Optional[Session] = None,
        keyspace: Optional[str] = None,
        table_name: str = DEFAULT_TABLE_NAME,
        ttl_seconds: Optional[int] = DEFAULT_TTL_SECONDS,
    ) -> None:
        """Set up the backing table for the given chat session.

        Raises:
            ImportError: if the ``cassio`` package cannot be imported.
        """
        # Imported lazily so that cassio is only required when this class
        # is actually instantiated.
        try:
            from cassio.table import ClusteredCassandraTable
        except ImportError as exc:
            # ModuleNotFoundError is a subclass of ImportError, so a single
            # clause covers both; chain the cause to keep the root error.
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            ) from exc

        self.session_id = session_id
        self.ttl_seconds = ttl_seconds
        self.table = ClusteredCassandraTable(
            session=session,
            keyspace=keyspace,
            table=table_name,
            ttl_seconds=ttl_seconds,
            # (partition key, clustering key): session id and time-based row id.
            primary_key_type=["TEXT", "TIMEUUID"],
            ordering_in_partition="DESC",
        )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve all messages of this session from the database.

        Returns:
            The stored messages, in chronological order.
        """
        rows = self.table.get_partition(partition_id=self.session_id)
        blobs = [row["body_blob"] for row in rows]
        # The partition is declared with DESC ordering, so the rows are
        # reversed here to hand back the messages oldest-first.
        blobs.reverse()
        return messages_from_dict([json.loads(blob) for blob in blobs])

    def add_message(self, message: BaseMessage) -> None:
        """Write a message to the table.

        Args:
            message: the message to write.
        """
        # A time-based UUID serves as the row id (the TIMEUUID key column).
        row_id = uuid.uuid1()
        self.table.put(
            partition_id=self.session_id,
            row_id=row_id,
            body_blob=json.dumps(message_to_dict(message)),
            ttl_seconds=self.ttl_seconds,
        )

    def clear(self) -> None:
        """Delete all stored messages of this session from the database."""
        self.table.delete_partition(self.session_id)