Source code for langchain_community.chat_message_histories.tidb

import json
import logging
from datetime import datetime
from typing import List, Optional

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, message_to_dict, messages_from_dict
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker

# Module-level logger; the chat-history class below reports database errors
# through it instead of raising them.
logger = logging.getLogger(__name__)


class TiDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a TiDB database.

    Each message is stored as one row (keyed by ``session_id``) in
    ``table_name`` and is also kept in an in-memory cache for fast reads.

    NOTE: ``table_name`` is interpolated directly into SQL statements (table
    identifiers cannot be bound as parameters), so it must come from a
    trusted source.
    """

    def __init__(
        self,
        session_id: str,
        connection_string: str,
        table_name: str = "langchain_message_store",
        earliest_time: Optional[datetime] = None,
    ):
        """Initialize a new TiDBChatMessageHistory instance.

        Creates the message table if it does not exist and loads the
        session's existing messages into the cache.

        Args:
            session_id (str): ID of the chat session.
            connection_string (str): Connection string for the TiDB database.
                Format: mysql+pymysql://<user>:<password>@<host>:4000/<db>?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true
            table_name (str, optional): Name of the table that stores the chat
                messages. Defaults to "langchain_message_store".
            earliest_time (Optional[datetime], optional): If given, only
                messages whose ``create_time`` is at or after this time are
                loaded. Defaults to None (no lower bound).
        """  # noqa
        self.session_id = session_id
        self.table_name = table_name
        self.earliest_time = earliest_time
        self.cache: List = []

        # Set up SQLAlchemy engine and session
        self.engine = create_engine(connection_string)
        Session = sessionmaker(bind=self.engine)
        self.session = Session()

        self._create_table_if_not_exists()
        self._load_messages_to_cache()

    def _create_table_if_not_exists(self) -> None:
        """Create the message table if it does not already exist.

        Errors are logged and the session is rolled back; they are not
        re-raised.
        """
        create_table_query = text(
            f"""
            CREATE TABLE IF NOT EXISTS {self.table_name} (
                id INT AUTO_INCREMENT PRIMARY KEY,
                session_id VARCHAR(255) NOT NULL,
                message JSON NOT NULL,
                create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX session_idx (session_id)
            );"""
        )
        try:
            self.session.execute(create_table_query)
            self.session.commit()
        except SQLAlchemyError as e:
            logger.error("Error creating table: %s", e)
            self.session.rollback()

    def _load_messages_to_cache(self) -> None:
        """Load this session's messages from the database into the cache.

        Rows are read in ``id`` (insertion) order and appended to
        ``self.cache``. When ``earliest_time`` is set, only rows with
        ``create_time >= earliest_time`` are loaded.

        Database errors are logged and the session is rolled back; they are
        not re-raised.
        """
        params: dict = {"session_id": self.session_id}
        time_condition = ""
        if self.earliest_time:
            # Bind the timestamp as a parameter rather than formatting it
            # into the SQL text.
            time_condition = "AND create_time >= :earliest_time"
            params["earliest_time"] = self.earliest_time

        query = text(
            f"""
            SELECT message FROM {self.table_name}
            WHERE session_id = :session_id {time_condition}
            ORDER BY id;
            """
        )
        try:
            result = self.session.execute(query, params)
            for record in result.fetchall():
                raw = record[0]
                # Decode text/bytes JSON; pass through a value the driver may
                # already have decoded.
                if isinstance(raw, (str, bytes, bytearray)):
                    message_dict = json.loads(raw)
                else:
                    message_dict = raw
                self.cache.append(messages_from_dict([message_dict])[0])
        except SQLAlchemyError as e:
            logger.error("Error loading messages to cache: %s", e)
            # Keep the session usable after a failed query, consistent with
            # the other database methods of this class.
            self.session.rollback()

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore[override]
        """Return all cached messages, reloading from the database if empty."""
        if len(self.cache) == 0:
            self.reload_cache()
        return self.cache

    def add_message(self, message: BaseMessage) -> None:
        """Insert a message into the database and append it to the cache.

        The cache is only updated after the insert commits successfully.
        Errors are logged and the session is rolled back.
        """
        query = text(
            f"INSERT INTO {self.table_name} (session_id, message) VALUES (:session_id, :message);"  # noqa
        )
        try:
            self.session.execute(
                query,
                {
                    "session_id": self.session_id,
                    "message": json.dumps(message_to_dict(message)),
                },
            )
            self.session.commit()
            self.cache.append(message)
        except SQLAlchemyError as e:
            logger.error("Error adding message: %s", e)
            self.session.rollback()

    def clear(self) -> None:
        """Delete all messages of this session from the database and cache.

        Errors are logged and the session is rolled back.
        """
        query = text(f"DELETE FROM {self.table_name} WHERE session_id = :session_id;")
        try:
            self.session.execute(query, {"session_id": self.session_id})
            self.session.commit()
            self.cache.clear()
        except SQLAlchemyError as e:
            logger.error("Error clearing messages: %s", e)
            self.session.rollback()

    def reload_cache(self) -> None:
        """Discard the cache and reload this session's messages from the database."""
        self.cache.clear()
        self._load_messages_to_cache()

    def __del__(self) -> None:
        """Close the database session, if one was created."""
        # __init__ may have raised before the session existed (e.g. on an
        # invalid connection string); in that case there is nothing to close.
        session = getattr(self, "session", None)
        if session is not None:
            session.close()