Source code for langchain_community.document_loaders.sql_database

from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Union

import sqlalchemy as sa

from langchain_community.docstore.document import Document
from langchain_community.document_loaders.base import BaseLoader
from langchain_community.utilities.sql_database import SQLDatabase


class SQLDatabaseLoader(BaseLoader):
    """Load documents by querying database tables supported by SQLAlchemy.

    To communicate with the database, the loader uses the `SQLDatabase`
    utility from the LangChain integration toolkit.

    Each document represents one row of the result set.
    """

    def __init__(
        self,
        query: Union[str, sa.Select],
        db: SQLDatabase,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        page_content_mapper: Optional[Callable[..., str]] = None,
        metadata_mapper: Optional[Callable[..., Dict[str, Any]]] = None,
        source_columns: Optional[Sequence[str]] = None,
        include_rownum_into_metadata: bool = False,
        include_query_into_metadata: bool = False,
    ):
        """Initialize the loader.

        Args:
            query: The query to execute, either a SQL string or an SQLAlchemy
                selectable.
            db: A LangChain `SQLDatabase` wrapping an SQLAlchemy engine.
            parameters: Optional. Parameters to pass to the query.
            page_content_mapper: Optional. Function converting a row into the
                string used as the document's `page_content`. By default, the
                whole row is serialized into a string, including all columns.
            metadata_mapper: Optional. Function converting a row into the
                dictionary used as the document's `metadata`. By default, no
                columns are selected into the metadata dictionary.
            source_columns: Optional. Names of the columns whose values are
                used as `source` in the metadata dictionary.
            include_rownum_into_metadata: Optional. Whether to include the row
                number in the metadata dictionary. Default: False.
            include_query_into_metadata: Optional. Whether to include the query
                expression in the metadata dictionary. Default: False.
        """
        self.query = query
        self.db: SQLDatabase = db
        self.parameters = parameters or {}
        self.page_content_mapper = (
            page_content_mapper or self.page_content_default_mapper
        )
        self.metadata_mapper = metadata_mapper or self.metadata_default_mapper
        self.source_columns = source_columns
        self.include_rownum_into_metadata = include_rownum_into_metadata
        self.include_query_into_metadata = include_query_into_metadata

    def lazy_load(self) -> Iterator[Document]:
        """Execute the query and yield one `Document` per result row.

        Yields:
            A `Document` whose `page_content` and `metadata` are produced by
            the configured mappers. Depending on the loader options, the
            metadata additionally carries `row` (zero-based row index),
            `query` (the SQL text) and `source` (comma-joined values of the
            configured source columns).

        Raises:
            TypeError: If `query` is neither a string nor an SQLAlchemy
                selectable.
        """
        # NOTE: `sqlalchemy` is already imported at module level (and is needed
        # there for the signature annotations), so no guarded re-import here.

        # Querying in `cursor` fetch mode will return an SQLAlchemy `Result` instance.
        result: sa.Result[Any]

        # Invoke the database query.
        if isinstance(self.query, sa.SelectBase):
            result = self.db._execute(  # type: ignore[assignment]
                self.query, fetch="cursor", parameters=self.parameters
            )
            query_sql = str(self.query.compile(bind=self.db._engine))
        elif isinstance(self.query, str):
            result = self.db._execute(  # type: ignore[assignment]
                sa.text(self.query), fetch="cursor", parameters=self.parameters
            )
            query_sql = self.query
        else:
            raise TypeError(f"Unable to process query of unknown type: {self.query}")

        # Iterate database result rows and generate documents lazily.
        for i, row in enumerate(result.mappings()):
            page_content = self.page_content_mapper(row)
            metadata = self.metadata_mapper(row)

            if self.include_rownum_into_metadata:
                metadata["row"] = i
            if self.include_query_into_metadata:
                metadata["query"] = query_sql

            if self.source_columns:
                # Values keep the row's column order. They are converted with
                # `str()` because `str.join` rejects non-string items, and
                # source columns are frequently integer ids or similar.
                source_values = [
                    str(value)
                    for column, value in row.items()
                    if column in self.source_columns
                ]
                if source_values:
                    metadata["source"] = ",".join(source_values)

            yield Document(page_content=page_content, metadata=metadata)

    @staticmethod
    def page_content_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> str:
        """Convert a record into a "page content" string (sensible default).

        Args:
            row: The result row, as a column-name-to-value mapping.
            column_names: Optional. Columns to include. When omitted, all
                columns of the row are included.

        Returns:
            One `column: value` line per selected column, joined by newlines.
        """
        if column_names is None:
            column_names = list(row.keys())
        return "\n".join(
            f"{column}: {value}"
            for column, value in row.items()
            if column in column_names
        )

    @staticmethod
    def metadata_default_mapper(
        row: sa.RowMapping, column_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Convert a record into a "metadata" dictionary (sensible default).

        Args:
            row: The result row, as a column-name-to-value mapping.
            column_names: Optional. Columns to copy into the metadata. When
                omitted, no columns are selected.

        Returns:
            A new dictionary holding the selected columns and their values;
            empty when `column_names` is None.
        """
        if column_names is None:
            return {}
        return {
            column: value for column, value in row.items() if column in column_names
        }