Source code for langchain_community.document_loaders.oracleadb_loader

from typing import Any, Dict, List, Optional

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class OracleAutonomousDatabaseLoader(BaseLoader):
    """Load documents from an Oracle Autonomous Database (ADB).

    The connection can be made either with a ``connection_string`` or with a
    TNS name (``tns_name``, optionally with ``config_dir`` pointing at the
    directory holding ``tnsnames.ora`` and the wallet). For a TLS/mTLS
    connection, ``wallet_location`` and ``wallet_password`` must be provided.

    Each document represents one row of the query result. The whole row
    (rendered with ``str``) becomes ``page_content``. The columns listed in
    the constructor's ``metadata`` argument (default ``None``, meaning no
    columns) are copied into the document's ``metadata``.
    """

    def __init__(
        self,
        query: str,
        user: str,
        password: str,
        *,
        schema: Optional[str] = None,
        tns_name: Optional[str] = None,
        config_dir: Optional[str] = None,
        wallet_location: Optional[str] = None,
        wallet_password: Optional[str] = None,
        connection_string: Optional[str] = None,
        metadata: Optional[List[str]] = None,
    ):
        """Initialize the loader.

        :param query: SQL query to execute.
        :param user: Database user name.
        :param password: Password for ``user``.
        :param schema: Schema to switch the session to before running the
            query (``alter session set current_schema``).
        :param tns_name: TNS name from ``tnsnames.ora``.
        :param config_dir: Directory with configuration files
            (``tnsnames.ora``, wallet); only passed when connecting by TNS name.
        :param wallet_location: Wallet location.
        :param wallet_password: Wallet password.
        :param connection_string: Connection string for the ADB instance;
            takes precedence over ``tns_name`` when both are given.
        :param metadata: Names of result columns to copy into each
            document's metadata.
        """
        # Mandatory required arguments.
        self.query = query
        self.user = user
        self.password = password
        # Schema
        self.schema = schema
        # TNS connection method
        self.tns_name = tns_name
        self.config_dir = config_dir
        # Wallet configuration is required for mTLS connection
        self.wallet_location = wallet_location
        self.wallet_password = wallet_password
        # Connection string connection method
        self.connection_string = connection_string
        # metadata columns
        self.metadata = metadata
        # dsn: always assigned (possibly None) so _run_query never sees a
        # missing attribute.
        self.dsn: Optional[str] = None
        self._set_dsn()

    def _set_dsn(self) -> None:
        """Select the DSN: ``connection_string`` wins over ``tns_name``.

        If neither is provided, ``self.dsn`` is set to ``None`` and the
        choice of target is left to ``oracledb.connect``.
        """
        if self.connection_string:
            self.dsn = self.connection_string
        elif self.tns_name:
            self.dsn = self.tns_name
        else:
            self.dsn = None

    def _run_query(self) -> List[Dict[str, Any]]:
        """Run ``self.query`` and return the rows as column-name -> value dicts.

        A ``oracledb.DatabaseError`` is reported on stdout and an empty list
        is returned (best-effort behavior kept from the original).

        :raises ImportError: if the ``oracledb`` package is not installed.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Could not import oracledb, "
                "please install with 'pip install oracledb'"
            ) from e

        connect_param: Dict[str, Any] = {
            "user": self.user,
            "password": self.password,
            "dsn": self.dsn,
        }
        # config_dir only applies when connecting through a TNS name; the
        # tns_name guard avoids matching the case where both are None.
        if self.tns_name and self.dsn == self.tns_name:
            connect_param["config_dir"] = self.config_dir
        if self.wallet_location and self.wallet_password:
            connect_param["wallet_location"] = self.wallet_location
            connect_param["wallet_password"] = self.wallet_password

        # Pre-bind so the finally clause never touches unbound names when
        # connect() or cursor() itself raises.
        connection = None
        cursor = None
        data: List[Dict[str, Any]] = []
        try:
            connection = oracledb.connect(**connect_param)
            cursor = connection.cursor()
            if self.schema:
                # NOTE: the schema is interpolated as an identifier (cannot be
                # a bind variable); it must come from trusted configuration.
                cursor.execute(
                    f"alter session set current_schema={self.schema}"
                )
            cursor.execute(self.query)
            # description is None for statements that return no rows.
            if cursor.description is not None:
                columns = [col[0] for col in cursor.description]
                data = [dict(zip(columns, row)) for row in cursor.fetchall()]
        except oracledb.DatabaseError as e:
            print("Got error while connecting: " + str(e))  # noqa: T201
            data = []
        finally:
            if cursor is not None:
                cursor.close()
            if connection is not None:
                connection.close()
        return data

    def load(self) -> List[Document]:
        """Run the query and return one ``Document`` per result row.

        :return: Documents whose ``page_content`` is ``str(row)`` and whose
            metadata holds the configured metadata columns of that row.
        """
        data = self._run_query()
        # Set membership: built once, O(1) per column lookup.
        metadata_columns = set(self.metadata) if self.metadata else set()
        documents = []
        for row in data:
            metadata = {
                key: value for key, value in row.items() if key in metadata_columns
            }
            documents.append(Document(page_content=str(row), metadata=metadata))
        return documents