Source code for langchain_community.vectorstores.oraclevs

from __future__ import annotations

import array
import functools
import hashlib
import json
import logging
import os
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

if TYPE_CHECKING:
    from oracledb import Connection

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import (
    DistanceStrategy,
    maximal_marginal_relevance,
)

# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)

# LOG_LEVEL (environment variable) selects the root logging level; it
# defaults to ERROR and is compared case-insensitively.
log_level = os.getenv("LOG_LEVEL", "ERROR").upper()

# An unrecognised LOG_LEVEL value must not make the module un-importable:
# a plain ``getattr(logging, log_level)`` raises AttributeError for unknown
# names, and may return a non-level attribute for names such as
# "BASIC_FORMAT".  Fall back to ERROR in both situations.
_numeric_level = getattr(logging, log_level, None)
if not isinstance(_numeric_level, int):
    _numeric_level = logging.ERROR

logging.basicConfig(
    level=_numeric_level,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# Type variable standing for "any callable"; lets decorators in this module
# declare that they return the same callable type they receive.
T = TypeVar("T", bound=Callable[..., Any])


def _handle_exceptions(func: T) -> T:
    """Wrap ``func`` so that every failure is logged and re-raised uniformly.

    The returned wrapper calls ``func`` with the same arguments and returns
    its result.  When ``func`` raises, the exception is logged with its
    traceback through the module ``logger`` and replaced as follows (the
    original exception is always chained as ``__cause__``):

    * ``RuntimeError`` -> ``RuntimeError("Failed due to a DB issue: ...")``
    * ``ValueError``   -> ``ValueError("Validation failed: ...")``
    * any other ``Exception`` -> ``RuntimeError("Unexpected error: ...")``

    Args:
        func: The callable to protect.

    Returns:
        The wrapping callable, typed as the same callable type as ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            outcome = func(*args, **kwargs)
        except RuntimeError as db_err:
            # RuntimeError is treated as the "DB-related" failure category.
            logger.exception("DB-related error occurred.")
            raise RuntimeError(f"Failed due to a DB issue: {db_err}") from db_err
        except ValueError as val_err:
            # ValueError keeps its type so callers can still catch it.
            logger.exception("Validation error.")
            raise ValueError(f"Validation failed: {val_err}") from val_err
        except Exception as e:
            # Everything else is funnelled into RuntimeError.
            logger.exception(f"An unexpected error occurred: {e}")
            raise RuntimeError(f"Unexpected error: {e}") from e
        else:
            return outcome

    return cast(T, wrapper)


def _table_exists(client: Connection, table_name: str) -> bool:
    """Report whether ``table_name`` can be queried through ``client``.

    The check runs ``SELECT COUNT(*)`` against the table and interprets a
    database error whose code is 942 as "the table does not exist".

    Args:
        client: Open database connection used to create the cursor.
        table_name: Name of the table to probe; it is interpolated into the
            SQL text as-is.

    Returns:
        ``True`` if the probe query ran, ``False`` if it failed with error
        code 942.

    Raises:
        ImportError: If the ``oracledb`` package is not installed.
        oracledb.DatabaseError: For any database error other than code 942.
    """
    try:
        import oracledb
    except ImportError as e:
        raise ImportError(
            "Unable to import oracledb, please install with "
            "`pip install -U oracledb`."
        ) from e

    probe_sql = f"SELECT COUNT(*) FROM {table_name}"
    try:
        with client.cursor() as cursor:
            cursor.execute(probe_sql)
    except oracledb.DatabaseError as ex:
        # The first exception argument carries the database error code.
        if ex.args[0].code != 942:
            raise
        return False
    return True


@_handle_exceptions
def _index_exists(client: Connection, index_name: str) -> bool:
    """Report whether an index named ``index_name`` is listed in ``all_indexes``.

    The comparison is case-insensitive: the name is upper-cased both before
    binding and inside the SQL predicate.

    Args:
        client: Open database connection used to create the cursor.
        index_name: Name of the index to look up.

    Returns:
        ``True`` if at least one matching row exists, ``False`` otherwise.
    """
    lookup_sql = """
        SELECT index_name 
        FROM all_indexes 
        WHERE upper(index_name) = upper(:idx_name)
        """

    with client.cursor() as cursor:
        cursor.execute(lookup_sql, idx_name=index_name.upper())
        first_row = cursor.fetchone()

    # No row at all means no such index.
    return first_row is not None


def _get_distance_function(distance_strategy: DistanceStrategy) -> str:
    """Translate a ``DistanceStrategy`` into the keyword used in the SQL text.

    Args:
        distance_strategy: The strategy to translate.  Only
            ``EUCLIDEAN_DISTANCE``, ``DOT_PRODUCT`` and ``COSINE`` are mapped.

    Returns:
        ``"EUCLIDEAN"``, ``"DOT"`` or ``"COSINE"``.

    Raises:
        ValueError: If ``distance_strategy`` is not one of the mapped members.
    """
    sql_keywords = {
        DistanceStrategy.EUCLIDEAN_DISTANCE: "EUCLIDEAN",
        DistanceStrategy.DOT_PRODUCT: "DOT",
        DistanceStrategy.COSINE: "COSINE",
    }

    keyword = sql_keywords.get(distance_strategy)
    if keyword is None:
        # Anything outside the table above has no SQL counterpart here.
        raise ValueError(f"Unsupported distance strategy: {distance_strategy}")
    return keyword


def _get_index_name(base_name: str) -> str:
    unique_id = str(uuid.uuid4()).replace("-", "")
    return f"{base_name}_{unique_id}"


@_handle_exceptions
def _create_table(client: Connection, table_name: str, embedding_dim: int) -> None:
    """Create the vector-store table unless it already exists.

    The table has four columns: ``id`` (``RAW(16)`` primary key defaulting to
    ``SYS_GUID()``), ``text`` (``CLOB``), ``metadata`` (``CLOB``) and
    ``embedding`` (``vector(<embedding_dim>, FLOAT32)``).

    Args:
        client: Open database connection used to run the DDL.
        table_name: Name of the table; interpolated into the DDL as-is.
        embedding_dim: Dimension declared for the ``embedding`` column.
    """
    if _table_exists(client, table_name):
        logger.info("Table already exists...")
        return

    # Column order matters: it is the order the columns appear in the DDL.
    column_specs = (
        ("id", "RAW(16) DEFAULT SYS_GUID() PRIMARY KEY"),
        ("text", "CLOB"),
        ("metadata", "CLOB"),
        ("embedding", f"vector({embedding_dim}, FLOAT32)"),
    )
    columns_sql = ", ".join(
        f"{column} {sql_type}" for column, sql_type in column_specs
    )

    with client.cursor() as cursor:
        cursor.execute(f"CREATE TABLE {table_name} ({columns_sql})")
    logger.info("Table created successfully...")


@_handle_exceptions
def create_index(
    client: Connection,
    vector_store: OracleVS,
    params: Optional[dict[str, Any]] = None,
) -> None:
    """Create a vector index on the table backing ``vector_store``.

    An IVF index is built when ``params["idx_type"]`` equals ``"IVF"``; in
    every other case (no params, no ``"idx_type"`` key, ``"HNSW"`` or any
    other value) an HNSW index is built.

    Args:
        client: Open database connection used to run the DDL.
        vector_store: Store whose ``table_name`` and ``distance_strategy``
            are used for the index.
        params: Optional index options, forwarded unchanged to the index
            builder.
    """
    # ``.get`` rather than ``params["idx_type"]``: a params dict that only
    # carries e.g. "idx_name" previously failed with KeyError instead of
    # falling back to the HNSW default.
    idx_type = params.get("idx_type") if params else None
    build_index = _create_ivf_index if idx_type == "IVF" else _create_hnsw_index
    build_index(
        client, vector_store.table_name, vector_store.distance_strategy, params
    )
def _resolve_index_config(
    params: Optional[dict[str, Any]], defaults: dict[str, Any]
) -> dict[str, Any]:
    """Merge caller-supplied index ``params`` with the builder's ``defaults``.

    * With no (or empty) ``params`` a copy of ``defaults`` is returned, so
      the index gets the fixed default name.
    * Otherwise every key must be one of the keys of ``defaults``.  A missing
      ``"idx_name"`` is replaced by a freshly generated unique name based on
      the default name; missing ``"idx_type"`` and ``"parallel"`` take their
      default values.  Other options stay absent when not supplied.

    The caller's dict is never mutated.

    Raises:
        ValueError: If ``params`` contains a key that is not in ``defaults``.
    """
    if not params:
        return dict(defaults)

    for key in params:
        if key not in defaults:
            raise ValueError(f"Invalid parameter: {key}")

    config = dict(params)
    if "idx_name" not in config:
        config["idx_name"] = _get_index_name(str(defaults["idx_name"]))
    # "idx_type" is filled in as well: the PARAMETERS clause needs it, and a
    # params dict without it used to fail (HNSW) or silently lose its
    # partition setting (IVF).
    for key in ("idx_type", "parallel"):
        config.setdefault(key, defaults[key])
    return config


def _create_index_if_missing(client: Connection, index_name: str, ddl: str) -> None:
    """Execute ``ddl`` unless an index called ``index_name`` already exists."""
    if _index_exists(client, index_name):
        logger.info("Index already exists...")
        return
    with client.cursor() as cursor:
        cursor.execute(ddl)
    logger.info("Index created successfully...")


@_handle_exceptions
def _create_hnsw_index(
    client: Connection,
    table_name: str,
    distance_strategy: DistanceStrategy,
    params: Optional[dict[str, Any]] = None,
) -> None:
    """Create an in-memory neighbor-graph (HNSW) vector index on ``table_name``.

    Args:
        client: Open database connection used to run the DDL.
        table_name: Table whose ``embedding`` column is indexed.
        distance_strategy: Strategy translated into the ``DISTANCE`` clause.
        params: Optional options.  Accepted keys are ``idx_name``,
            ``idx_type``, ``neighbors``, ``efConstruction``, ``accuracy`` and
            ``parallel``.  Without params all defaults below are used.

    Raises:
        ValueError: For an unknown key in ``params`` or an unsupported
            distance strategy.
    """
    defaults = {
        "idx_name": "HNSW",
        "idx_type": "HNSW",
        "neighbors": 32,
        "efConstruction": 200,
        "accuracy": 90,
        "parallel": 8,
    }
    config = _resolve_index_config(params, defaults)
    idx_name = config["idx_name"]

    ddl_parts = [
        f"create vector index {idx_name} on {table_name}(embedding) "
        f"ORGANIZATION INMEMORY NEIGHBOR GRAPH"
    ]
    # The accuracy clause is optional and only emitted when configured.
    if "accuracy" in config:
        ddl_parts.append(f" WITH TARGET ACCURACY {config['accuracy']}")
    ddl_parts.append(f" DISTANCE {_get_distance_function(distance_strategy)}")
    # If only one of neighbors / efConstruction is given, the other one takes
    # its default so the PARAMETERS clause is always complete.
    if "neighbors" in config or "efConstruction" in config:
        neighbors = config.get("neighbors", defaults["neighbors"])
        ef_construction = config.get("efConstruction", defaults["efConstruction"])
        ddl_parts.append(
            f" parameters (type {config['idx_type']}, neighbors {neighbors}, "
            f"efConstruction {ef_construction})"
        )
    ddl_parts.append(f" parallel {config['parallel']}")

    _create_index_if_missing(client, idx_name, "".join(ddl_parts))


@_handle_exceptions
def _create_ivf_index(
    client: Connection,
    table_name: str,
    distance_strategy: DistanceStrategy,
    params: Optional[dict[str, Any]] = None,
) -> None:
    """Create a neighbor-partitions (IVF) vector index on ``table_name``.

    Args:
        client: Open database connection used to run the DDL.
        table_name: Table whose ``embedding`` column is indexed.
        distance_strategy: Strategy translated into the ``DISTANCE`` clause.
        params: Optional options.  Accepted keys are ``idx_name``,
            ``idx_type``, ``neighbor_part``, ``accuracy`` and ``parallel``.
            Without params all defaults below are used.

    Raises:
        ValueError: For an unknown key in ``params`` or an unsupported
            distance strategy.
    """
    defaults = {
        "idx_name": "IVF",
        "idx_type": "IVF",
        "neighbor_part": 32,
        "accuracy": 90,
        "parallel": 8,
    }
    config = _resolve_index_config(params, defaults)
    idx_name = config["idx_name"]

    ddl_parts = [
        f"CREATE VECTOR INDEX {idx_name} ON {table_name}(embedding) "
        f"ORGANIZATION NEIGHBOR PARTITIONS"
    ]
    # The accuracy clause is optional and only emitted when configured.
    if "accuracy" in config:
        ddl_parts.append(f" WITH TARGET ACCURACY {config['accuracy']}")
    ddl_parts.append(f" DISTANCE {_get_distance_function(distance_strategy)}")
    # The PARAMETERS clause is only emitted when a partition count is set.
    if "neighbor_part" in config:
        ddl_parts.append(
            f" PARAMETERS (type {config['idx_type']}, neighbor"
            f" partitions {config['neighbor_part']})"
        )
    ddl_parts.append(f" PARALLEL {config['parallel']}")

    _create_index_if_missing(client, idx_name, "".join(ddl_parts))
@_handle_exceptions
def drop_table_purge(client: Connection, table_name: str) -> None:
    """Drop ``table_name`` with the ``PURGE`` option if the table exists.

    When the table is not found nothing is executed; either outcome is
    logged at INFO level.

    Args:
        client: Open database connection used to run the DDL.
        table_name: Name of the table to drop; interpolated into the DDL
            as-is.
    """
    if not _table_exists(client, table_name):
        logger.info("Table not found...")
        return

    with client.cursor() as cursor:
        cursor.execute(f"DROP TABLE {table_name} PURGE")
    logger.info("Table dropped successfully...")
@_handle_exceptions
def drop_index_if_exists(client: Connection, index_name: str) -> None:
    """Drop the index ``index_name`` if it exists.

    When the index is not found nothing is executed; either outcome is
    logged at INFO level.

    Args:
        client: Open database connection used to run the DDL.
        index_name: Name of the index to drop; interpolated into the DDL
            as-is.
    """
    if not _index_exists(client, index_name):
        # A missing index is an expected outcome of an "if exists" helper.
        # ``logger.exception`` is reserved for use inside exception handlers
        # (outside one it logs at ERROR level with an empty traceback), so
        # this is reported at INFO level like the missing-table case.
        logger.info(f"Index {index_name} does not exist.")
        return

    with client.cursor() as cursor:
        cursor.execute(f"DROP INDEX {index_name}")
    logger.info(f"Index {index_name} has been dropped.")
class OracleVS(VectorStore):
    """`OracleVS` vector store.

    To use, you should have both:

    - the ``oracledb`` python package installed
    - a connection string associated with an OracleDBCluster that has a
      search index deployed

    Example:
        .. code-block:: python

            from langchain_community.vectorstores.oraclevs import OracleVS
            from langchain.embeddings.openai import OpenAIEmbeddings
            import oracledb

            with oracledb.connect(user=user, password=pwd, dsn=dsn) as connection:
                print("Database version:", connection.version)
                embeddings = OpenAIEmbeddings()
                query = ""
                vectors = OracleVS(connection, embeddings, table_name, query=query)
    """

    def __init__(
        self,
        client: Connection,
        embedding_function: Union[
            Callable[[str], List[float]],
            Embeddings,
        ],
        table_name: str,
        distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
        query: Optional[str] = "What is a Oracle database",
        params: Optional[Dict[str, Any]] = None,
    ):
        """Store the components and make sure the backing table exists.

        Args:
            client: Open ``oracledb`` connection.
            embedding_function: ``Embeddings`` object, or a plain callable
                mapping one text to one vector (a warning is logged for the
                callable form).
            table_name: Name of the table holding the documents.
            distance_strategy: Strategy used in similarity queries.
            query: Text embedded once at construction time to discover the
                embedding dimension of the table's vector column.
            params: Extra parameters; only stored on the instance here.

        Raises:
            ImportError: If the ``oracledb`` package is not installed.
            RuntimeError: If embedding the probe text or creating the table
                fails for any reason.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        try:
            self.client = client
            if not isinstance(embedding_function, Embeddings):
                logger.warning(
                    "`embedding_function` is expected to be an Embeddings "
                    "object, support "
                    "for passing in a function will soon be removed."
                )
            self.embedding_function = embedding_function
            self.query = query
            # Needs ``embedding_function`` and ``query`` to be set already.
            embedding_dim = self.get_embedding_dimension()

            self.table_name = table_name
            self.distance_strategy = distance_strategy
            self.params = params

            _create_table(client, table_name, embedding_dim)
        except oracledb.DatabaseError as db_err:
            logger.exception(f"Database error occurred while create table: {db_err}")
            raise RuntimeError(
                "Failed to create table due to a database error."
            ) from db_err
        except ValueError as val_err:
            logger.exception(f"Validation error: {val_err}")
            raise RuntimeError(
                "Failed to create table due to a validation error."
            ) from val_err
        except Exception as ex:
            # This path creates the table, not an index; the log message now
            # says so.
            logger.exception("An unexpected error occurred while creating the table.")
            raise RuntimeError(
                "Failed to create table due to an unexpected error."
            ) from ex

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Return the embedding function if it is an ``Embeddings`` instance.

        Returns:
            Optional[Embeddings]: ``self.embedding_function`` when it is an
            ``Embeddings`` instance, otherwise ``None``.
        """
        return (
            self.embedding_function
            if isinstance(self.embedding_function, Embeddings)
            else None
        )

    def get_embedding_dimension(self) -> int:
        """Return the length of the vector produced for ``self.query``.

        An empty string is embedded when ``self.query`` is ``None``.
        """
        # Embed the single probe text by wrapping it in a list.
        embedded_document = self._embed_documents(
            [self.query if self.query is not None else ""]
        )
        # The first (and only) embedding gives the dimension.
        return len(embedded_document[0])

    def _embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the configured embedding function.

        Raises:
            TypeError: If the embedding function is neither an ``Embeddings``
                instance nor callable.
        """
        if isinstance(self.embedding_function, Embeddings):
            return self.embedding_function.embed_documents(texts)
        elif callable(self.embedding_function):
            return [self.embedding_function(text) for text in texts]
        else:
            raise TypeError(
                "The embedding_function is neither Embeddings nor callable."
            )

    def _embed_query(self, text: str) -> List[float]:
        """Embed a single query text with the configured embedding function."""
        if isinstance(self.embedding_function, Embeddings):
            return self.embedding_function.embed_query(text)
        else:
            return self.embedding_function(text)

    @staticmethod
    def _hash_id(raw_id: str) -> str:
        """Return the stored form of an id.

        The stored form is the first 16 hexadecimal characters of the
        SHA-256 digest of ``raw_id``, upper-cased.
        """
        return hashlib.sha256(raw_id.encode()).hexdigest()[:16].upper()

    @staticmethod
    def _matches_filter(
        metadata: Dict[str, Any], criteria: Optional[Dict[str, Any]]
    ) -> bool:
        """Tell whether ``metadata`` satisfies the metadata filter.

        With no criteria everything matches.  Otherwise, for every key in
        ``criteria`` the metadata value must be contained in (``in``) the
        corresponding criteria value, so criteria values must be containers
        such as lists.
        """
        if not criteria:
            return True
        return all(
            metadata.get(key) in allowed for key, allowed in criteria.items()
        )

    def _load_metadata(self, raw: Any) -> Dict[str, Any]:
        """Decode a metadata column value; NULL or empty content gives ``{}``."""
        text = self._get_clob_value(raw) if raw is not None else ""
        return json.loads(text) if text else {}

    @_handle_exceptions
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[Any, Any]]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add more texts to the vectorstore index.

        Ids are taken from ``ids`` when given, otherwise from the ``"id"``
        entry of every metadata dict when all of them have one, otherwise
        random UUIDs are generated.  In all cases the ids are hashed before
        being stored and it is the hashed ids that are returned.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids for the texts that are being added to
                the vector store.
            kwargs: vectorstore specific parameters

        Returns:
            The stored (hashed) ids, one per text.

        Raises:
            ValueError: If ``ids`` or ``metadatas`` is given with a length
                different from the number of texts.
        """
        texts = list(texts)

        # ``zip`` below would otherwise silently drop the surplus entries and
        # the returned ids would no longer describe what was inserted.
        if ids and len(ids) != len(texts):
            raise ValueError(
                f"Number of ids ({len(ids)}) does not match number of "
                f"texts ({len(texts)})."
            )
        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Number of metadatas ({len(metadatas)}) does not match "
                f"number of texts ({len(texts)})."
            )
        if not texts:
            # Nothing to embed or insert; skip the database round-trip.
            return []

        if ids:
            # If ids are provided, hash them to maintain consistency
            processed_ids = [self._hash_id(_id) for _id in ids]
        elif metadatas and all("id" in metadata for metadata in metadatas):
            # If no ids are provided but metadatas with ids are, generate
            # ids from metadatas
            processed_ids = [
                self._hash_id(metadata["id"]) for metadata in metadatas
            ]
        else:
            # Generate new ids if none are provided
            processed_ids = [self._hash_id(str(uuid.uuid4())) for _ in texts]

        embeddings = self._embed_documents(texts)
        if not metadatas:
            metadatas = [{} for _ in texts]
        docs = [
            (id_, text, json.dumps(metadata), array.array("f", embedding))
            for id_, text, metadata, embedding in zip(
                processed_ids, texts, metadatas, embeddings
            )
        ]

        with self.client.cursor() as cursor:
            cursor.executemany(
                f"INSERT INTO {self.table_name} (id, text, metadata, "
                f"embedding) VALUES (:1, :2, :3, :4)",
                docs,
            )
            self.client.commit()
        return processed_ids

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return the documents closest to ``embedding``, without distances."""
        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding=embedding, k=k, filter=filter, **kwargs
        )
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query."""
        # ``_embed_query`` handles both Embeddings objects and plain
        # callables; embedding only in the Embeddings case left ``embedding``
        # unbound for callables.
        embedding = self._embed_query(query)
        docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
            embedding=embedding, k=k, filter=filter, **kwargs
        )
        return docs_and_scores

    @_handle_exceptions
    def _get_clob_value(self, result: Any) -> str:
        """Convert a fetched column value into ``str``.

        LOB values are read (and decoded as UTF-8 when the read returns
        bytes), strings are returned unchanged and falsy values give ``""``.

        Raises:
            ImportError: If the ``oracledb`` package is not installed.
            Exception: For a truthy value that is neither a LOB nor a string.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        clob_value = ""
        if result:
            if isinstance(result, oracledb.LOB):
                raw_data = result.read()
                if isinstance(raw_data, bytes):
                    clob_value = raw_data.decode("utf-8")
                else:
                    clob_value = raw_data
            elif isinstance(result, str):
                clob_value = result
            else:
                raise Exception("Unexpected type:", type(result))
        return clob_value

    @_handle_exceptions
    def similarity_search_by_vector_with_relevance_scores(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return the documents closest to ``embedding`` with their distances.

        The ``k`` nearest rows are fetched first and the metadata filter is
        applied afterwards, so fewer than ``k`` documents may be returned
        when a filter is given.

        Args:
            embedding: Query vector.
            k: Number of rows to fetch.
            filter: Optional mapping of metadata key to a container of
                accepted values.

        Returns:
            ``(Document, distance)`` pairs ordered by ascending distance.
        """
        distance_function = _get_distance_function(self.distance_strategy)
        embedding_arr = array.array("f", embedding)

        query = f"""
        SELECT id,
          text,
          metadata,
          vector_distance(embedding, :embedding,
          {distance_function}) as distance
        FROM {self.table_name}
        ORDER BY distance
        FETCH APPROX FIRST {k} ROWS ONLY
        """

        docs_and_scores = []
        with self.client.cursor() as cursor:
            cursor.execute(query, embedding=embedding_arr)
            results = cursor.fetchall()

            # Row layout: (id, text, metadata, distance).
            for result in results:
                metadata = self._load_metadata(result[2])
                if not self._matches_filter(metadata, filter):
                    continue
                doc = Document(
                    page_content=(
                        self._get_clob_value(result[1])
                        if result[1] is not None
                        else ""
                    ),
                    metadata=metadata,
                )
                docs_and_scores.append((doc, result[3]))

        return docs_and_scores

    @_handle_exceptions
    def similarity_search_by_vector_returning_embeddings(
        self,
        embedding: List[float],
        k: int,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, np.ndarray[np.float32, Any]]]:
        """Like the relevance-score search, but also return stored embeddings.

        Args:
            embedding: Query vector.
            k: Number of rows to fetch (the filter is applied afterwards).
            filter: Optional mapping of metadata key to a container of
                accepted values.

        Returns:
            ``(Document, distance, embedding)`` triples ordered by ascending
            distance; the embedding is a ``float32`` array, empty when the
            stored vector is missing or empty.
        """
        distance_function = _get_distance_function(self.distance_strategy)
        embedding_arr = array.array("f", embedding)

        query = f"""
        SELECT id,
          text,
          metadata,
          vector_distance(embedding, :embedding, {distance_function}) as distance,
          embedding
        FROM {self.table_name}
        ORDER BY distance
        FETCH APPROX FIRST {k} ROWS ONLY
        """

        documents = []
        with self.client.cursor() as cursor:
            cursor.execute(query, embedding=embedding_arr)
            results = cursor.fetchall()

            # Row layout: (id, text, metadata, distance, embedding).
            for result in results:
                page_content_str = self._get_clob_value(result[1])
                # NULL/empty metadata is treated as {} here as well, instead
                # of failing inside ``json.loads("")``.
                metadata = self._load_metadata(result[2])
                if not self._matches_filter(metadata, filter):
                    continue

                document = Document(
                    page_content=page_content_str, metadata=metadata
                )
                stored_vector = result[4]
                # Explicit None/length test: plain truthiness is ambiguous
                # for multi-element numpy arrays.
                if stored_vector is not None and len(stored_vector) > 0:
                    current_embedding = np.array(stored_vector, dtype=np.float32)
                else:
                    current_embedding = np.empty(0, dtype=np.float32)
                documents.append((document, result[3], current_embedding))
        return documents  # type: ignore

    @_handle_exceptions
    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: List[float],
        *,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
    ) -> List[Tuple[Document, float]]:
        """Return docs and their similarity scores selected using the
        maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Filter by metadata. Defaults to None.

        Returns:
            List of Documents and similarity scores selected by maximal
            marginal relevance and score for each.
        """
        docs_scores_embeddings = self.similarity_search_by_vector_returning_embeddings(
            embedding, fetch_k, filter=filter
        )
        # Split the triples into three parallel sequences.
        documents, scores, embeddings = (
            zip(*docs_scores_embeddings) if docs_scores_embeddings else ([], [], [])
        )

        mmr_selected_indices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            list(embeddings),
            k=k,
            lambda_mult=lambda_mult,
        )

        # Map the selected indices back to documents and scores.
        return [(documents[i], scores[i]) for i in mmr_selected_indices]

    @_handle_exceptions
    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional metadata filter.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        return [doc for doc, _ in docs_and_scores]

    @_handle_exceptions
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Delete by vector IDs.

        The given ids are hashed the same way as when texts are added, then
        the matching rows are deleted and the transaction is committed.

        Args:
            ids: List of ids to delete.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``ids`` is ``None`` or empty.
        """
        # An empty list would otherwise produce the invalid SQL ``IN ()``.
        if not ids:
            raise ValueError("No ids provided to delete.")

        # One named bind variable per id: id1, id2, ...
        bind_vars = {
            f"id{i}": self._hash_id(_id) for i, _id in enumerate(ids, start=1)
        }
        placeholders = ", ".join(f":{name}" for name in bind_vars)
        statement = f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})"

        with self.client.cursor() as cursor:
            cursor.execute(statement, bind_vars)
            self.client.commit()

    @classmethod
    @_handle_exceptions
    def from_texts(
        cls: Type[OracleVS],
        texts: Iterable[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> OracleVS:
        """Return VectorStore initialized from texts and embeddings.

        NOTE: any existing table named ``table_name`` is dropped (with
        ``PURGE``) before the new store is created.

        Args:
            texts: Texts to add to the new store.
            embedding: Embeddings object used by the store.
            metadatas: Optional list of metadatas associated with the texts.
            **kwargs: ``client`` (required), ``distance_strategy`` (required,
                a ``DistanceStrategy``), ``table_name`` (default
                ``"langchain"``), ``query`` and ``params``.

        Raises:
            ValueError: If ``client`` is missing.
            TypeError: If ``distance_strategy`` is not a ``DistanceStrategy``.
        """
        client = kwargs.get("client")
        if client is None:
            raise ValueError("client parameter is required...")

        params = kwargs.get("params", {})
        table_name = str(kwargs.get("table_name", "langchain"))

        distance_strategy = cast(
            DistanceStrategy, kwargs.get("distance_strategy", None)
        )
        if not isinstance(distance_strategy, DistanceStrategy):
            raise TypeError(
                f"Expected DistanceStrategy got "
                f"{type(distance_strategy).__name__} "
            )

        query = kwargs.get("query", "What is a Oracle database")

        # Start from a clean table.
        drop_table_purge(client, table_name)

        vss = cls(
            client=client,
            embedding_function=embedding,
            table_name=table_name,
            distance_strategy=distance_strategy,
            query=query,
            params=params,
        )
        vss.add_texts(texts=list(texts), metadatas=metadatas)
        return vss