# Source code for langchain_community.document_loaders.sql_database
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union
import sqlalchemy as sa
from langchain_community.docstore.document import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.sql_database import SQLDatabase
[docs]class SQLDatabaseLoader(BaseLoader):
"""通过查询SQLAlchemy支持的数据库表加载文档。
为了与数据库通信,文档加载器使用LangChain集成工具包中的`SQLDatabase`实用程序。
每个文档代表结果集中的一行。"""
def __init__(
    self,
    query: Union[str, sa.Select],
    db: SQLDatabase,
    *,
    parameters: Optional[Dict[str, Any]] = None,
    page_content_mapper: Optional[Callable[..., str]] = None,
    metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
    source_columns: Optional[Sequence[str]] = None,
    include_rownum_into_metadata: bool = False,
    include_query_into_metadata: bool = False,
):
    """Configure the loader; no query is executed until documents are loaded.

    Args:
        query: The query to execute, either a SQL string or a SQLAlchemy
            ``Select`` expression.
        db: A LangChain ``SQLDatabase`` wrapping a SQLAlchemy engine.
        parameters: Optional. Parameters to pass to the query. ``None`` is
            stored as an empty dict.
        page_content_mapper: Optional. Function converting a row into the
            string used as the document's ``page_content``. By default the
            loader serializes the whole row, including all columns, into a
            string (``page_content_default_mapper``).
        metadata_mapper: Optional. Function converting a row into the dict
            used as the document's ``metadata``. By default
            (``metadata_default_mapper``), no columns are selected into the
            metadata dict.
        source_columns: Optional. Names of the columns whose values are used
            as ``source`` in the metadata dict.
        include_rownum_into_metadata: Optional. Whether to include the row
            number in the metadata dict. Default: False.
        include_query_into_metadata: Optional. Whether to include the query
            expression in the metadata dict. Default: False.
    """
    self.query = query
    self.db: SQLDatabase = db
    self.parameters = parameters or {}
    # Fall back to the class's own default mappers when none are supplied.
    self.page_content_mapper = (
        page_content_mapper or self.page_content_default_mapper
    )
    self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
    self.source_columns = source_columns
    self.include_rownum_into_metadata = include_rownum_into_metadata
    self.include_query_into_metadata = include_query_into_metadata
def lazy_load(self) -> Iterator[Document]:
    """Execute the configured query and yield one ``Document`` per result row.

    The query is run through ``self.db._execute`` in ``"cursor"`` fetch mode.
    Rows are iterated as mappings and converted with ``page_content_mapper``
    and ``metadata_mapper``. The metadata dict may be augmented as follows:

    * ``"row"``: the zero-based row index, if
      ``include_rownum_into_metadata`` is set.
    * ``"query"``: the SQL text, if ``include_query_into_metadata`` is set.
      For SQLAlchemy selects this is the statement compiled against the
      database engine.
    * ``"source"``: the values of the columns named in ``source_columns``,
      converted to ``str`` and comma-joined in result-column order. It is
      set only when at least one such column is present.

    Raises:
        TypeError: If ``query`` is neither a SQLAlchemy ``SelectBase`` nor a
            ``str``.
    """
    # Querying in `cursor` fetch mode will return an SQLAlchemy `Result` instance.
    # (`sa` is the module-level sqlalchemy import; a local re-import guard
    # would be unreachable because that import already succeeded.)
    result: sa.Result[Any]
    # Invoke the database query.
    if isinstance(self.query, sa.SelectBase):
        result = self.db._execute(  # type: ignore[assignment]
            self.query, fetch="cursor", parameters=self.parameters
        )
        query_sql = str(self.query.compile(bind=self.db._engine))
    elif isinstance(self.query, str):
        result = self.db._execute(  # type: ignore[assignment]
            sa.text(self.query), fetch="cursor", parameters=self.parameters
        )
        query_sql = self.query
    else:
        raise TypeError(f"Unable to process query of unknown type: {self.query}")

    # Hoisted out of the row loop; an empty tuple matches no column, which
    # mirrors the original "no source_columns configured" short-circuit.
    source_columns = self.source_columns or ()

    # Iterate database result rows and generate documents.
    for i, row in enumerate(result.mappings()):
        page_content = self.page_content_mapper(row)
        # Copy so that a user mapper returning a shared/cached dict is not
        # mutated by the keys added below.
        metadata = dict(self.metadata_mapper(row))
        if self.include_rownum_into_metadata:
            metadata["row"] = i
        if self.include_query_into_metadata:
            metadata["query"] = query_sql
        # str() is required: str.join raises TypeError on non-string items,
        # and source columns are commonly ints (ids), dates, or NULL/None.
        source_values = [
            str(value) for column, value in row.items() if column in source_columns
        ]
        if source_values:
            metadata["source"] = ",".join(source_values)
        yield Document(page_content=page_content, metadata=metadata)
@staticmethod
def page_content_default_mapper(
    row: sa.RowMapping, column_names: Optional[List[str]] = None
) -> str:
    """Render a result row as ``column: value`` lines for ``page_content``.

    Args:
        row: The result row, accessed as a column-name to value mapping.
        column_names: Optional. Only columns whose name appears in this list
            are rendered. When ``None``, every column of ``row`` is rendered.

    Returns:
        One ``"<column>: <value>"`` line per selected column, in the row's
        own column order, joined by newline characters.
    """
    selected = list(row.keys()) if column_names is None else column_names
    lines = []
    for name, val in row.items():
        if name in selected:
            lines.append(f"{name}: {val}")
    return "\n".join(lines)