"""Astra DB-based chat message history, built on astrapy."""

from __future__ import annotations

import json
import time
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence

from langchain_community.utilities.astradb import (
    SetupMode,
    _AstraDBCollectionEnvironment,
)

if TYPE_CHECKING:
    from astrapy.db import AstraDB, AsyncAstraDB

from langchain_core._api.deprecation import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
    BaseMessage,
    message_to_dict,
    messages_from_dict,
)

# Astra DB collection used by AstraDBChatMessageHistory when the caller does
# not supply an explicit ``collection_name``.
DEFAULT_COLLECTION_NAME: str = "langchain_message_store"

@deprecated(
    since="0.0.25",
    removal="0.3.0",
    alternative_import="langchain_astradb.AstraDBChatMessageHistory",
)
class AstraDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history backed by an Astra DB collection.

    Each message is stored as its own document holding the ``session_id``,
    a ``timestamp`` (a ``time.time()``-based float) and a ``body_blob`` with
    the JSON-serialized message. Messages are read back sorted by ascending
    ``timestamp``.
    """

    # Gap, in seconds, between the timestamps assigned to consecutive
    # messages of a single add_messages/aadd_messages call. It is larger than
    # the float64 spacing at present-day epoch values (about 2.4e-7 s), so
    # timestamps within one batch are strictly increasing and never tie.
    _TIMESTAMP_STEP = 1e-6

    def __init__(
        self,
        *,
        session_id: str,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
    ) -> None:
        """Chat message history that stores history in Astra DB.

        Args:
            session_id: Arbitrary key used to store the messages of a single
                chat session.
            collection_name: Name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: Full URL to the API endpoint, such as
                "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client: *Alternative to token+api_endpoint*: an
                already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *Alternative to token+api_endpoint*: an
                already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: Namespace (aka keyspace) in which the collection is
                created. Defaults to the database's "default namespace".
            setup_mode: Mode used to create the Astra DB collection
                (SYNC, ASYNC or OFF).
            pre_delete_collection: Whether to delete the collection before
                creating it. If False and the collection already exists, it
                is used as is.
        """
        # Client construction and collection setup are delegated to the
        # shared environment helper; this class keeps handles to the sync and
        # async collection objects it exposes.
        self.astra_env = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )
        self.collection = self.astra_env.collection
        self.async_collection = self.astra_env.async_collection
        self.session_id = session_id
        self.collection_name = collection_name

    def _session_filter(self) -> Dict[str, Any]:
        """Return a fresh query filter selecting this session's documents."""
        return {"session_id": self.session_id}

    def _messages_to_documents(
        self, messages: Sequence[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """Build the documents to insert for ``messages``, in input order.

        One ``time.time()`` reading is taken per batch, and each message gets
        that value plus ``index * _TIMESTAMP_STEP``. Reading the clock once
        per message could return the same value for consecutive messages
        (coarse clock resolution). Since retrieval orders only by
        ``timestamp``, such ties would let the database's return order decide
        the order of those messages.
        """
        base_timestamp = time.time()
        return [
            {
                "timestamp": base_timestamp + index * self._TIMESTAMP_STEP,
                "session_id": self.session_id,
                "body_blob": json.dumps(message_to_dict(message)),
            }
            for index, message in enumerate(messages)
        ]

    @staticmethod
    def _documents_to_messages(
        documents: Iterable[Dict[str, Any]],
    ) -> List[BaseMessage]:
        """Sort stored documents by ``timestamp`` and decode their messages.

        Each document must carry a ``timestamp`` key and a ``body_blob`` key
        holding the JSON written by ``_messages_to_documents``.
        """
        ordered = sorted(documents, key=lambda doc: doc["timestamp"])
        return messages_from_dict([json.loads(doc["body_blob"]) for doc in ordered])

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve all messages of this session, by ascending timestamp."""
        self.astra_env.ensure_db_setup()
        documents = self.collection.paginated_find(
            filter=self._session_filter(),
            projection={
                "timestamp": 1,
                "body_blob": 1,
            },
        )
        return self._documents_to_messages(documents)

    @messages.setter
    def messages(self, messages: List[BaseMessage]) -> None:
        # Direct assignment is not supported; writes go through add_messages.
        raise NotImplementedError("Use add_messages instead")

    async def aget_messages(self) -> List[BaseMessage]:
        """Async counterpart of ``messages``: this session's messages by
        ascending timestamp."""
        await self.astra_env.aensure_db_setup()
        cursor = self.async_collection.paginated_find(
            filter=self._session_filter(),
            projection={
                "timestamp": 1,
                "body_blob": 1,
            },
        )
        # Drain the async iterator first: sorting needs every document.
        documents = [doc async for doc in cursor]
        return self._documents_to_messages(documents)

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Store ``messages`` for this session, keeping their relative order."""
        self.astra_env.ensure_db_setup()
        self.collection.chunked_insert_many(self._messages_to_documents(messages))

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Async counterpart of ``add_messages``."""
        await self.astra_env.aensure_db_setup()
        await self.async_collection.chunked_insert_many(
            self._messages_to_documents(messages)
        )

    def clear(self) -> None:
        """Delete every stored document belonging to this session."""
        self.astra_env.ensure_db_setup()
        self.collection.delete_many(filter=self._session_filter())

    async def aclear(self) -> None:
        """Async counterpart of ``clear``."""
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_many(filter=self._session_filter())