import json
import logging
import re
from typing import (
Any,
List,
)
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
logger = logging.getLogger(__name__)
[docs]class SingleStoreDBChatMessageHistory(BaseChatMessageHistory):
"""在SingleStoreDB数据库中存储的聊天消息历史。"""
[docs] def __init__(
self,
session_id: str,
*,
table_name: str = "message_store",
id_field: str = "id",
session_id_field: str = "session_id",
message_field: str = "message",
pool_size: int = 5,
max_overflow: int = 10,
timeout: float = 30,
**kwargs: Any,
):
"""初始化必要组件。
参数:
table_name(str,可选):指定正在使用的表的名称。默认为“message_store”。
id_field(str,可选):指定表中id字段的名称。默认为“id”。
session_id_field(str,可选):指定表中session_id字段的名称。默认为“session_id”。
message_field(str,可选):指定表中message字段的名称。默认为“message”。
以下参数与连接池相关:
pool_size(int,可选):确定池中活动连接的数量。默认为5。
max_overflow(int,可选):确定允许超出pool_size的最大连接数。默认为10。
timeout(float,可选):指定建立连接的最大等待时间(以秒为单位)。默认为30。
以下参数与数据库连接相关:
host(str,可选):指定数据库连接的主机名、IP地址或URL。默认方案为“mysql”。
user(str,可选):数据库用户名。
password(str,可选):数据库密码。
port(int,可选):数据库端口。对于非HTTP连接,默认为3306,对于HTTP连接,默认为80,对于HTTPS连接,默认为443。
database(str,可选):数据库名称。
其他可选参数可进一步定制数据库连接:
pure_python(bool,可选):切换连接器模式。如果为True,则以纯Python模式运行。
local_infile(bool,可选):允许本地文件上传。
charset(str,可选):指定字符串值的字符集。
ssl_key(str,可选):指定包含SSL密钥的文件路径。
ssl_cert(str,可选):指定包含SSL证书的文件路径。
ssl_ca(str,可选):指定包含SSL证书颁发机构的文件路径。
ssl_cipher(str,可选):设置SSL密码列表。
ssl_disabled(bool,可选):禁用SSL使用。
ssl_verify_cert(bool,可选):验证服务器的证书。如果指定了“ssl_ca”,则自动启用。
ssl_verify_identity(bool,可选):验证服务器的身份。
conv(dict[int,Callable],可选):数据转换函数的字典。
credential_type(str,可选):指定要使用的身份验证类型:auth.PASSWORD、auth.JWT或auth.BROWSER_SSO。
autocommit(bool,可选):启用自动提交。
results_type(str,可选):确定查询结果的结构:元组、命名元组、字典。
results_format(str,可选):已弃用。此选项已更名为results_type。
示例:
基本用法:
.. code-block:: python
from langchain_community.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="https://user:password@127.0.0.1:3306/database"
)
高级用法:
.. code-block:: python
from langchain_community.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
message_history = SingleStoreDBChatMessageHistory(
session_id="my-session",
host="127.0.0.1",
port=3306,
user="user",
password="password",
database="db",
table_name="my_custom_table",
pool_size=10,
timeout=60,
)
使用环境变量:
.. code-block:: python
from langchain_community.chat_message_histories import (
SingleStoreDBChatMessageHistory
)
os.environ['SINGLESTOREDB_URL'] = 'me:p455w0rd@s2-host.com/my_db'
message_history = SingleStoreDBChatMessageHistory("my-session")
"""
self.table_name = self._sanitize_input(table_name)
self.session_id = self._sanitize_input(session_id)
self.id_field = self._sanitize_input(id_field)
self.session_id_field = self._sanitize_input(session_id_field)
self.message_field = self._sanitize_input(message_field)
# Pass the rest of the kwargs to the connection.
self.connection_kwargs = kwargs
# Add connection attributes to the connection kwargs.
if "conn_attrs" not in self.connection_kwargs:
self.connection_kwargs["conn_attrs"] = dict()
self.connection_kwargs["conn_attrs"]["_connector_name"] = "langchain python sdk"
self.connection_kwargs["conn_attrs"]["_connector_version"] = "1.0.1"
# Create a connection pool.
try:
from sqlalchemy.pool import QueuePool
except ImportError:
raise ImportError(
"Could not import sqlalchemy.pool python package. "
"Please install it with `pip install singlestoredb`."
)
self.connection_pool = QueuePool(
self._get_connection,
max_overflow=max_overflow,
pool_size=pool_size,
timeout=timeout,
)
self.table_created = False
def _sanitize_input(self, input_str: str) -> str:
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
def _get_connection(self) -> Any:
try:
import singlestoredb as s2
except ImportError:
raise ImportError(
"Could not import singlestoredb python package. "
"Please install it with `pip install singlestoredb`."
)
return s2.connect(**self.connection_kwargs)
def _create_table_if_not_exists(self) -> None:
"""如果表不存在,则创建表。"""
if self.table_created:
return
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""CREATE TABLE IF NOT EXISTS {}
({} BIGINT PRIMARY KEY AUTO_INCREMENT,
{} TEXT NOT NULL,
{} JSON NOT NULL);""".format(
self.table_name,
self.id_field,
self.session_id_field,
self.message_field,
),
)
self.table_created = True
finally:
cur.close()
finally:
conn.close()
@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""从SingleStoreDB检索消息"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
items = []
try:
cur = conn.cursor()
try:
cur.execute(
"""SELECT {} FROM {} WHERE {} = %s""".format(
self.message_field,
self.table_name,
self.session_id_field,
),
(self.session_id),
)
for row in cur.fetchall():
items.append(row[0])
finally:
cur.close()
finally:
conn.close()
messages = messages_from_dict(items)
return messages
[docs] def add_message(self, message: BaseMessage) -> None:
"""将消息附加到SingleStoreDB中的记录"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""INSERT INTO {} ({}, {}) VALUES (%s, %s)""".format(
self.table_name,
self.session_id_field,
self.message_field,
),
(self.session_id, json.dumps(message_to_dict(message))),
)
finally:
cur.close()
finally:
conn.close()
[docs] def clear(self) -> None:
"""清除SingleStoreDB中的会话内存"""
self._create_table_if_not_exists()
conn = self.connection_pool.connect()
try:
cur = conn.cursor()
try:
cur.execute(
"""DELETE FROM {} WHERE {} = %s""".format(
self.table_name,
self.session_id_field,
),
(self.session_id),
)
finally:
cur.close()
finally:
conn.close()