Source code for langchain_community.chat_message_histories.cosmos_db
"""Azure CosmosDB 内存历史。"""
from __future__ import annotations
import logging
from types import TracebackType
from typing import TYPE_CHECKING, Any, List, Optional, Type
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
messages_from_dict,
messages_to_dict,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from azure.cosmos import ContainerProxy
class CosmosDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history backed by an Azure Cosmos DB container.

    All messages of one session are stored in a single item whose ``id`` is
    the session ID and whose ``user_id`` field (the container's partition key,
    ``/user_id``) is the user ID.
    """

    # Message raised whenever the optional azure-cosmos dependency cannot be
    # imported. Kept in one place so every lazy import reports the same text.
    _IMPORT_ERROR_MESSAGE = (
        "You must install the azure-cosmos package to use the "
        "CosmosDBChatMessageHistory. "
        "Please install it with `pip install azure-cosmos`."
    )

    def __init__(
        self,
        cosmos_endpoint: str,
        cosmos_database: str,
        cosmos_container: str,
        session_id: str,
        user_id: str,
        credential: Any = None,
        connection_string: Optional[str] = None,
        ttl: Optional[int] = None,
        cosmos_client_kwargs: Optional[dict] = None,
    ):
        """Initialize a new CosmosDBChatMessageHistory instance.

        Call ``prepare_cosmos`` (or use the instance as a context manager)
        before reading or writing messages, so that the database and container
        exist. Either ``credential`` or ``connection_string`` must be provided;
        if both are given, ``credential`` is used.

        :param cosmos_endpoint: Connection endpoint of the Azure Cosmos DB account.
        :param cosmos_database: Name of the database to use.
        :param cosmos_container: Name of the container to use.
        :param session_id: Session ID; used as the ``id`` of the stored item.
        :param user_id: User ID; used as the partition-key value of the item.
        :param credential: Credential used to authenticate with Azure Cosmos DB.
        :param connection_string: Connection string used to authenticate.
        :param ttl: Time-to-live in seconds, passed as ``default_ttl`` when
            ``prepare_cosmos`` creates the container.
        :param cosmos_client_kwargs: Extra keyword arguments passed to
            ``CosmosClient``.
        :raises ImportError: If the azure-cosmos package is not installed.
        :raises ValueError: If neither a credential nor a connection string is set.
        """
        self.cosmos_endpoint = cosmos_endpoint
        self.cosmos_database = cosmos_database
        self.cosmos_container = cosmos_container
        self.credential = credential
        self.conn_string = connection_string
        self.session_id = session_id
        self.user_id = user_id
        self.ttl = ttl
        self.messages: List[BaseMessage] = []

        try:
            from azure.cosmos import (  # pylint: disable=import-outside-toplevel
                CosmosClient,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc

        client_kwargs = cosmos_client_kwargs or {}
        if self.credential:
            self._client = CosmosClient(
                url=self.cosmos_endpoint,
                credential=self.credential,
                **client_kwargs,
            )
        elif self.conn_string:
            self._client = CosmosClient.from_connection_string(
                conn_str=self.conn_string,
                **client_kwargs,
            )
        else:
            raise ValueError("Either a connection string or a credential must be set.")
        # Set by prepare_cosmos(); None means the history is not usable yet.
        self._container: Optional[ContainerProxy] = None

    @classmethod
    def _not_found_error(cls) -> Type[BaseException]:
        """Lazily import and return azure-cosmos's "resource not found" error."""
        try:
            from azure.cosmos.exceptions import (  # pylint: disable=import-outside-toplevel
                CosmosResourceNotFoundError,
            )
        except ImportError as exc:
            raise ImportError(cls._IMPORT_ERROR_MESSAGE) from exc
        return CosmosResourceNotFoundError

    def prepare_cosmos(self) -> None:
        """Create the database/container if needed and load this session.

        Call this (or use the context manager) before using the history.

        :raises ImportError: If the azure-cosmos package is not installed.
        """
        try:
            from azure.cosmos import (  # pylint: disable=import-outside-toplevel
                PartitionKey,
            )
        except ImportError as exc:
            raise ImportError(self._IMPORT_ERROR_MESSAGE) from exc
        database = self._client.create_database_if_not_exists(self.cosmos_database)
        self._container = database.create_container_if_not_exists(
            self.cosmos_container,
            partition_key=PartitionKey("/user_id"),
            default_ttl=self.ttl,
        )
        self.load_messages()

    def __enter__(self) -> "CosmosDBChatMessageHistory":
        """Enter the client's context and prepare the database/container."""
        self._client.__enter__()
        try:
            self.prepare_cosmos()
        except BaseException as exc:
            # Python does not call __exit__ when __enter__ raises, so release
            # the client we just entered before propagating the error.
            self._client.__exit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Persist the in-memory messages, then exit the client's context."""
        try:
            self.upsert_messages()
        finally:
            # Always release the client, even if persisting the messages failed.
            self._client.__exit__(exc_type, exc_val, traceback)

    def load_messages(self) -> None:
        """Load this session's messages from Cosmos DB into ``self.messages``.

        A missing item (HTTP 404) is treated as an empty, new session. Any other
        Cosmos error propagates: swallowing it would leave ``self.messages``
        empty and the next upsert would overwrite the stored conversation.

        :raises ValueError: If the container has not been prepared.
        :raises ImportError: If the azure-cosmos package is not installed.
        """
        if self._container is None:
            raise ValueError("Container not initialized")
        not_found = self._not_found_error()
        try:
            item = self._container.read_item(
                item=self.session_id, partition_key=self.user_id
            )
        except not_found:
            logger.info("no session found")
            return
        stored_messages = item.get("messages")
        if stored_messages:
            self.messages = messages_from_dict(stored_messages)

    def add_message(self, message: BaseMessage) -> None:
        """Append ``message`` to the history and persist the whole session."""
        self.messages.append(message)
        self.upsert_messages()

    def upsert_messages(self) -> None:
        """Write the current ``self.messages`` to the session's Cosmos item.

        :raises ValueError: If the container has not been prepared.
        """
        if self._container is None:
            raise ValueError("Container not initialized")
        self._container.upsert_item(
            body={
                "id": self.session_id,
                "user_id": self.user_id,
                "messages": messages_to_dict(self.messages),
            }
        )

    def clear(self) -> None:
        """Clear the in-memory messages and delete the session item in Cosmos.

        If the container has not been prepared, only the in-memory list is
        cleared. Deleting a session that was never persisted is a no-op.
        """
        self.messages = []
        if self._container is None:
            return
        not_found = self._not_found_error()
        try:
            self._container.delete_item(
                item=self.session_id, partition_key=self.user_id
            )
        except not_found:
            # Nothing was ever stored for this session; it is already clear.
            logger.info("no session found")