Source code for langchain_community.vectorstores.azuresearch

from __future__ import annotations

import base64
import json
import logging
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    ClassVar,
    Collection,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Type,
    Union,
)

import numpy as np
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_env
from langchain_core.vectorstores import VectorStore

logger = logging.getLogger()

if TYPE_CHECKING:
    from azure.search.documents import SearchClient
    from azure.search.documents.indexes.models import (
        CorsOptions,
        ScoringProfile,
        SearchField,
        SemanticConfiguration,
        VectorSearch,
    )

# Allow overriding field names for Azure Search
FIELDS_ID = get_from_env(
    key="AZURESEARCH_FIELDS_ID", env_key="AZURESEARCH_FIELDS_ID", default="id"
)
FIELDS_CONTENT = get_from_env(
    key="AZURESEARCH_FIELDS_CONTENT",
    env_key="AZURESEARCH_FIELDS_CONTENT",
    default="content",
)
FIELDS_CONTENT_VECTOR = get_from_env(
    key="AZURESEARCH_FIELDS_CONTENT_VECTOR",
    env_key="AZURESEARCH_FIELDS_CONTENT_VECTOR",
    default="content_vector",
)
FIELDS_METADATA = get_from_env(
    key="AZURESEARCH_FIELDS_TAG", env_key="AZURESEARCH_FIELDS_TAG", default="metadata"
)

MAX_UPLOAD_BATCH_SIZE = 1000
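The FIELDS_* constants above are resolved once, at import time, via get_from_env, so any overrides must be in place before this module is imported. A minimal sketch of overriding the field names (the replacement names are purely illustrative):

import os

# Hypothetical index whose fields do not use the default names.
os.environ["AZURESEARCH_FIELDS_ID"] = "doc_id"
os.environ["AZURESEARCH_FIELDS_CONTENT"] = "body"
os.environ["AZURESEARCH_FIELDS_CONTENT_VECTOR"] = "body_vector"
os.environ["AZURESEARCH_FIELDS_TAG"] = "doc_metadata"  # note: this variable maps to FIELDS_METADATA

# Import only after the environment is set; otherwise the defaults are used.
from langchain_community.vectorstores.azuresearch import AzureSearch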


def _get_search_client(
    endpoint: str,
    key: str,
    index_name: str,
    semantic_configuration_name: Optional[str] = None,
    fields: Optional[List[SearchField]] = None,
    vector_search: Optional[VectorSearch] = None,
    semantic_configurations: Optional[
        Union[SemanticConfiguration, List[SemanticConfiguration]]
    ] = None,
    scoring_profiles: Optional[List[ScoringProfile]] = None,
    default_scoring_profile: Optional[str] = None,
    default_fields: Optional[List[SearchField]] = None,
    user_agent: Optional[str] = "langchain",
    cors_options: Optional[CorsOptions] = None,
) -> SearchClient:
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import ResourceNotFoundError
    from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
    from azure.search.documents import SearchClient
    from azure.search.documents.indexes import SearchIndexClient
    from azure.search.documents.indexes.models import (
        ExhaustiveKnnAlgorithmConfiguration,
        ExhaustiveKnnParameters,
        HnswAlgorithmConfiguration,
        HnswParameters,
        SearchIndex,
        SemanticConfiguration,
        SemanticField,
        SemanticPrioritizedFields,
        SemanticSearch,
        VectorSearch,
        VectorSearchAlgorithmKind,
        VectorSearchAlgorithmMetric,
        VectorSearchProfile,
    )

    default_fields = default_fields or []
    if key is None:
        credential = DefaultAzureCredential()
    elif key.upper() == "INTERACTIVE":
        credential = InteractiveBrowserCredential()
        credential.get_token("https://search.azure.com/.default")
    else:
        credential = AzureKeyCredential(key)
    index_client: SearchIndexClient = SearchIndexClient(
        endpoint=endpoint, credential=credential, user_agent=user_agent
    )
    try:
        index_client.get_index(name=index_name)
    except ResourceNotFoundError:
        # Fields configuration
        if fields is not None:
            # Check mandatory fields
            fields_types = {f.name: f.type for f in fields}
            mandatory_fields = {df.name: df.type for df in default_fields}
            # Check for missing keys
            missing_fields = {
                key: mandatory_fields[key]
                for key, value in set(mandatory_fields.items())
                - set(fields_types.items())
            }
            if len(missing_fields) > 0:
                # Helper for formatting field information for each missing field.
                def fmt_err(x: str) -> str:
                    return (
                        f"{x} current type: '{fields_types.get(x, 'MISSING')}'. "
                        f"It has to be '{mandatory_fields.get(x)}' or you can point "
                        f"to a different '{mandatory_fields.get(x)}' field name by "
                        f"using the env variable 'AZURESEARCH_FIELDS_{x.upper()}'"
                    )

                error = "\n".join([fmt_err(x) for x in missing_fields])
                raise ValueError(
                    f"You need to specify at least the following fields "
                    f"{missing_fields} or provide alternative field names in the env "
                    f"variables.\n\n{error}"
                )
        else:
            fields = default_fields
        # Vector search configuration
        if vector_search is None:
            vector_search = VectorSearch(
                algorithms=[
                    HnswAlgorithmConfiguration(
                        name="default",
                        kind=VectorSearchAlgorithmKind.HNSW,
                        parameters=HnswParameters(
                            m=4,
                            ef_construction=400,
                            ef_search=500,
                            metric=VectorSearchAlgorithmMetric.COSINE,
                        ),
                    ),
                    ExhaustiveKnnAlgorithmConfiguration(
                        name="default_exhaustive_knn",
                        kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                        parameters=ExhaustiveKnnParameters(
                            metric=VectorSearchAlgorithmMetric.COSINE
                        ),
                    ),
                ],
                profiles=[
                    VectorSearchProfile(
                        name="myHnswProfile",
                        algorithm_configuration_name="default",
                    ),
                    VectorSearchProfile(
                        name="myExhaustiveKnnProfile",
                        algorithm_configuration_name="default_exhaustive_knn",
                    ),
                ],
            )

        # Create the semantic settings with the configuration
        if semantic_configurations:
            if not isinstance(semantic_configurations, list):
                semantic_configurations = [semantic_configurations]
            semantic_search = SemanticSearch(
                configurations=semantic_configurations,
                default_configuration_name=semantic_configuration_name,
            )
        elif semantic_configuration_name:
            # use default semantic configuration
            semantic_configuration = SemanticConfiguration(
                name=semantic_configuration_name,
                prioritized_fields=SemanticPrioritizedFields(
                    content_fields=[SemanticField(field_name=FIELDS_CONTENT)],
                ),
            )
            semantic_search = SemanticSearch(configurations=[semantic_configuration])
        else:
            # don't use semantic search
            semantic_search = None

        # Create the search index with the semantic settings and vector search
        index = SearchIndex(
            name=index_name,
            fields=fields,
            vector_search=vector_search,
            semantic_search=semantic_search,
            scoring_profiles=scoring_profiles,
            default_scoring_profile=default_scoring_profile,
            cors_options=cors_options,
        )
        index_client.create_index(index)
    # Create the search client
    return SearchClient(
        endpoint=endpoint,
        index_name=index_name,
        credential=credential,
        user_agent=user_agent,
    )
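For reference, the credential branch above maps the key argument onto three authentication flows. The helper is module-private and is normally reached through AzureSearch.__init__, so the sketch below is illustrative only; the endpoint, key, and index name are placeholders.

from langchain_community.vectorstores.azuresearch import _get_search_client

endpoint = "https://example-search-service.search.windows.net"  # placeholder

# key=None -> DefaultAzureCredential (managed identity, environment, CLI, ...)
client_aad = _get_search_client(endpoint, key=None, index_name="langchain-index")

# key="interactive" -> InteractiveBrowserCredential, which opens a browser sign-in
client_browser = _get_search_client(endpoint, key="interactive", index_name="langchain-index")

# any other value -> AzureKeyCredential built from an admin key
client_key = _get_search_client(endpoint, key="YOUR_ADMIN_KEY", index_name="langchain-index")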


class AzureSearch(VectorStore):
    """`Azure Cognitive Search` vector store."""

    def __init__(
        self,
        azure_search_endpoint: str,
        azure_search_key: str,
        index_name: str,
        embedding_function: Union[Callable, Embeddings],
        search_type: str = "hybrid",
        semantic_configuration_name: Optional[str] = None,
        fields: Optional[List[SearchField]] = None,
        vector_search: Optional[VectorSearch] = None,
        semantic_configurations: Optional[
            Union[SemanticConfiguration, List[SemanticConfiguration]]
        ] = None,
        scoring_profiles: Optional[List[ScoringProfile]] = None,
        default_scoring_profile: Optional[str] = None,
        cors_options: Optional[CorsOptions] = None,
        **kwargs: Any,
    ):
        """Initialize with necessary components."""
        from azure.search.documents.indexes.models import (
            SearchableField,
            SearchField,
            SearchFieldDataType,
            SimpleField,
        )

        # Initialize base class
        self.embedding_function = embedding_function

        if isinstance(self.embedding_function, Embeddings):
            self.embed_query = self.embedding_function.embed_query
        else:
            self.embed_query = self.embedding_function

        default_fields = [
            SimpleField(
                name=FIELDS_ID,
                type=SearchFieldDataType.String,
                key=True,
                filterable=True,
            ),
            SearchableField(
                name=FIELDS_CONTENT,
                type=SearchFieldDataType.String,
            ),
            SearchField(
                name=FIELDS_CONTENT_VECTOR,
                type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
                searchable=True,
                vector_search_dimensions=len(self.embed_query("Text")),
                vector_search_profile_name="myHnswProfile",
            ),
            SearchableField(
                name=FIELDS_METADATA,
                type=SearchFieldDataType.String,
            ),
        ]
        user_agent = "langchain"
        if "user_agent" in kwargs and kwargs["user_agent"]:
            user_agent += " " + kwargs["user_agent"]
        self.client = _get_search_client(
            azure_search_endpoint,
            azure_search_key,
            index_name,
            semantic_configuration_name=semantic_configuration_name,
            fields=fields,
            vector_search=vector_search,
            semantic_configurations=semantic_configurations,
            scoring_profiles=scoring_profiles,
            default_scoring_profile=default_scoring_profile,
            default_fields=default_fields,
            user_agent=user_agent,
            cors_options=cors_options,
        )
        self.search_type = search_type
        self.semantic_configuration_name = semantic_configuration_name
        self.fields = fields if fields else default_fields

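A typical way to construct the store is to pass an Embeddings object, so that both embed_query and batched embed_documents are available. A minimal sketch; the OpenAIEmbeddings backend and the endpoint/key values are assumptions, not part of this module:

from langchain_community.embeddings import OpenAIEmbeddings  # assumed embedding backend
from langchain_community.vectorstores.azuresearch import AzureSearch

embeddings = OpenAIEmbeddings()  # requires OPENAI_API_KEY in the environment

vector_store = AzureSearch(
    azure_search_endpoint="https://example-search-service.search.windows.net",  # placeholder
    azure_search_key="YOUR_ADMIN_KEY",                                          # placeholder
    index_name="langchain-index",
    embedding_function=embeddings,
)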
    @property
    def embeddings(self) -> Optional[Embeddings]:
        # TODO: Support embedding object directly
        return None

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts data to an existing index."""
        keys = kwargs.get("keys")
        ids = []

        # batching support if embedding function is an Embeddings object
        if isinstance(self.embedding_function, Embeddings):
            try:
                embeddings = self.embedding_function.embed_documents(texts)  # type: ignore[arg-type]
            except NotImplementedError:
                embeddings = [self.embedding_function.embed_query(x) for x in texts]
        else:
            embeddings = [self.embedding_function(x) for x in texts]

        if len(embeddings) == 0:
            logger.debug("Nothing to insert, skipping.")
            return []

        # Write data to index
        data = []
        for i, text in enumerate(texts):
            # Use provided key otherwise use default key
            key = keys[i] if keys else str(uuid.uuid4())
            # Encoding key for Azure Search valid characters
            key = base64.urlsafe_b64encode(bytes(key, "utf-8")).decode("ascii")
            metadata = metadatas[i] if metadatas else {}
            # Add data to index
            # Additional metadata to fields mapping
            doc = {
                "@search.action": "upload",
                FIELDS_ID: key,
                FIELDS_CONTENT: text,
                FIELDS_CONTENT_VECTOR: np.array(
                    embeddings[i], dtype=np.float32
                ).tolist(),
                FIELDS_METADATA: json.dumps(metadata),
            }
            if metadata:
                additional_fields = {
                    k: v
                    for k, v in metadata.items()
                    if k in [x.name for x in self.fields]
                }
                doc.update(additional_fields)
            data.append(doc)
            ids.append(key)
            # Upload data in batches
            if len(data) == MAX_UPLOAD_BATCH_SIZE:
                response = self.client.upload_documents(documents=data)
                # Check if all documents were successfully uploaded
                if not all([r.succeeded for r in response]):
                    raise Exception(response)
                # Reset data
                data = []

        # Considering case where data is an exact multiple of batch-size entries
        if len(data) == 0:
            return ids

        # Upload data to index
        response = self.client.upload_documents(documents=data)
        # Check if all documents were successfully uploaded
        if all([r.succeeded for r in response]):
            return ids
        else:
            raise Exception(response)

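Usage sketch for the method above, assuming the vector_store built in the earlier construction example; the "source" metadata key is illustrative:

ids = vector_store.add_texts(
    texts=[
        "Azure Cognitive Search supports vector queries.",
        "LangChain integrates with many vector stores.",
    ],
    metadatas=[{"source": "docs"}, {"source": "blog"}],
)
# Each returned id is the base64-encoded document key that was uploaded.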
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> bool:
        """Delete by vector ID.

        Args:
            ids: List of ids to delete.

        Returns:
            bool: True if deletion is successful, False otherwise.
        """
        if ids:
            res = self.client.delete_documents([{"id": i} for i in ids])
            return len(res) > 0
        else:
            return False

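Deletion operates on the same encoded keys that add_texts returns; note that the delete payload hardcodes the default "id" field name. A sketch, reusing the ids from the ingestion example above:

deleted = vector_store.delete(ids=ids)
print(deleted)  # True if Azure Search returned at least one delete result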
    def similarity_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        score_threshold = kwargs.pop("score_threshold", None)
        result = self.vector_search_with_score(query, k=k, **kwargs)
        return (
            result
            if score_threshold is None
            else [r for r in result if r[1] >= score_threshold]
        )

    def vector_search_with_score(
        self, query: str, k: int = 4, filters: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        from azure.search.documents.models import VectorizedQuery

        results = self.client.search(
            search_text="",
            vector_queries=[
                VectorizedQuery(
                    vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
                    k_nearest_neighbors=k,
                    fields=FIELDS_CONTENT_VECTOR,
                )
            ],
            filter=filters,
            top=k,
        )
        # Convert results to Document objects
        docs = [
            (
                Document(
                    page_content=result.pop(FIELDS_CONTENT),
                    metadata=json.loads(result[FIELDS_METADATA])
                    if FIELDS_METADATA in result
                    else {
                        k: v for k, v in result.items() if k != FIELDS_CONTENT_VECTOR
                    },
                ),
                float(result["@search.score"]),
            )
            for result in results
        ]
        return docs

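A sketch of calling the pure vector search above, again assuming the vector_store from the construction example. The optional filters argument accepts an Azure Search OData filter expression and is omitted here:

docs_and_scores = vector_store.vector_search_with_score(
    "How do vector queries work?",  # the query text is embedded with embed_query
    k=3,
)
for doc, score in docs_and_scores:
    print(round(score, 4), doc.page_content[:80])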
    def hybrid_search_with_score(
        self, query: str, k: int = 4, filters: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query with a hybrid query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        from azure.search.documents.models import VectorizedQuery

        results = self.client.search(
            search_text=query,
            vector_queries=[
                VectorizedQuery(
                    vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
                    k_nearest_neighbors=k,
                    fields=FIELDS_CONTENT_VECTOR,
                )
            ],
            filter=filters,
            top=k,
        )
        # Convert results to Document objects
        docs = [
            (
                Document(
                    page_content=result.pop(FIELDS_CONTENT),
                    metadata=json.loads(result[FIELDS_METADATA])
                    if FIELDS_METADATA in result
                    else {
                        k: v for k, v in result.items() if k != FIELDS_CONTENT_VECTOR
                    },
                ),
                float(result["@search.score"]),
            )
            for result in results
        ]
        return docs

    def hybrid_search_with_relevance_scores(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        score_threshold = kwargs.pop("score_threshold", None)
        result = self.hybrid_search_with_score(query, k=k, **kwargs)
        return (
            result
            if score_threshold is None
            else [r for r in result if r[1] >= score_threshold]
        )

    def semantic_hybrid_search_with_score(
        self,
        query: str,
        k: int = 4,
        score_type: Literal["score", "reranker_score"] = "score",
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return indexed documents most similar to the query text.

        Args:
            query (str): Text to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
            score_type: Must be either "score" or "reranker_score".
                Defaults to "score".

        Returns:
            List[Tuple[Document, float]]: A list of documents and their
                corresponding scores.
        """
        score_threshold = kwargs.pop("score_threshold", None)
        docs_and_scores = self.semantic_hybrid_search_with_score_and_rerank(
            query, k=k, filters=kwargs.get("filters", None)
        )
        if score_type == "score":
            return [
                (doc, score)
                for doc, score, _ in docs_and_scores
                if score_threshold is None or score >= score_threshold
            ]
        elif score_type == "reranker_score":
            return [
                (doc, reranker_score)
                for doc, _, reranker_score in docs_and_scores
                if score_threshold is None or reranker_score >= score_threshold
            ]

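An illustrative call that keeps only documents whose reranker score clears a threshold. This assumes the store was created with a semantic_configuration_name and that semantic ranking is enabled on the service; the threshold value is an arbitrary cut-off:

reranked = vector_store.semantic_hybrid_search_with_score(
    "vector queries",
    k=4,
    score_type="reranker_score",
    score_threshold=2.0,  # arbitrary cut-off on the reranker score
)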
    def semantic_hybrid_search_with_score_and_rerank(
        self, query: str, k: int = 4, filters: Optional[str] = None
    ) -> List[Tuple[Document, float, float]]:
        """Return docs most similar to query with a hybrid query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        from azure.search.documents.models import VectorizedQuery

        results = self.client.search(
            search_text=query,
            vector_queries=[
                VectorizedQuery(
                    vector=np.array(self.embed_query(query), dtype=np.float32).tolist(),
                    k_nearest_neighbors=k,
                    fields=FIELDS_CONTENT_VECTOR,
                )
            ],
            filter=filters,
            query_type="semantic",
            semantic_configuration_name=self.semantic_configuration_name,
            query_caption="extractive",
            query_answer="extractive",
            top=k,
        )
        # Get Semantic Answers
        semantic_answers = results.get_answers() or []
        semantic_answers_dict: Dict = {}
        for semantic_answer in semantic_answers:
            semantic_answers_dict[semantic_answer.key] = {
                "text": semantic_answer.text,
                "highlights": semantic_answer.highlights,
            }
        # Convert results to Document objects
        docs = [
            (
                Document(
                    page_content=result.pop(FIELDS_CONTENT),
                    metadata={
                        **(
                            json.loads(result[FIELDS_METADATA])
                            if FIELDS_METADATA in result
                            else {
                                k: v
                                for k, v in result.items()
                                if k != FIELDS_CONTENT_VECTOR
                            }
                        ),
                        **{
                            "captions": {
                                "text": result.get("@search.captions", [{}])[0].text,
                                "highlights": result.get("@search.captions", [{}])[
                                    0
                                ].highlights,
                            }
                            if result.get("@search.captions")
                            else {},
                            "answers": semantic_answers_dict.get(
                                result.get(FIELDS_ID, ""),
                                "",
                            ),
                        },
                    },
                ),
                float(result["@search.score"]),
                float(result["@search.reranker_score"]),
            )
            for result in results
        ]
        return docs

    @classmethod
    def from_texts(
        cls: Type[AzureSearch],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        azure_search_endpoint: str = "",
        azure_search_key: str = "",
        index_name: str = "langchain-index",
        fields: Optional[List[SearchField]] = None,
        **kwargs: Any,
    ) -> AzureSearch:
        # Creating a new Azure Search instance
        azure_search = cls(
            azure_search_endpoint,
            azure_search_key,
            index_name,
            embedding,
            fields=fields,
        )
        azure_search.add_texts(texts, metadatas, **kwargs)
        return azure_search

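The classmethod above wires construction and ingestion together. A sketch, reusing the assumed embeddings object and placeholder credentials from the earlier examples:

vector_store = AzureSearch.from_texts(
    texts=["first document", "second document"],
    embedding=embeddings,
    azure_search_endpoint="https://example-search-service.search.windows.net",  # placeholder
    azure_search_key="YOUR_ADMIN_KEY",                                          # placeholder
    index_name="langchain-index",
)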
    def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever:  # type: ignore
        """Return AzureSearchVectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "hybrid", or "semantic_hybrid".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Number of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Number of documents to pass to the MMR algorithm
                        (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            AzureSearchVectorStoreRetriever: Retriever class for VectorStore.
        """
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._get_retriever_tags())
        return AzureSearchVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)

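A sketch of turning the store into a retriever and invoking it; the search_type and k values shown are just one valid combination, and vector_store is the instance from the earlier examples:

retriever = vector_store.as_retriever(
    search_type="hybrid",
    k=5,
)
docs = retriever.invoke("What does hybrid search do?")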

class AzureSearchVectorStoreRetriever(BaseRetriever):
    """Retriever that uses `Azure Cognitive Search`."""

    vectorstore: AzureSearch
    """Azure Search instance used to find similar documents."""
    search_type: str = "hybrid"
    """Type of search to perform. Options are "similarity", "hybrid",
    "semantic_hybrid", "similarity_score_threshold", "hybrid_score_threshold",
    or "semantic_hybrid_score_threshold"."""
    k: int = 4
    """Number of documents to return."""
    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "hybrid",
        "hybrid_score_threshold",
        "semantic_hybrid",
        "semantic_hybrid_score_threshold",
    )

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def validate_search_type(cls, values: Dict) -> Dict:
        """Validate search type."""
        if "search_type" in values:
            search_type = values["search_type"]
            if search_type not in cls.allowed_search_types:
                raise ValueError(
                    f"search_type of {search_type} not allowed. Valid values are: "
                    f"{cls.allowed_search_types}"
                )
        return values

    def _get_relevant_documents(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        if self.search_type == "similarity":
            docs = self.vectorstore.vector_search(query, k=self.k, **kwargs)
        elif self.search_type == "similarity_score_threshold":
            docs = [
                doc
                for doc, _ in self.vectorstore.similarity_search_with_relevance_scores(
                    query, k=self.k, **kwargs
                )
            ]
        elif self.search_type == "hybrid":
            docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs)
        elif self.search_type == "hybrid_score_threshold":
            docs = [
                doc
                for doc, _ in self.vectorstore.hybrid_search_with_relevance_scores(
                    query, k=self.k, **kwargs
                )
            ]
        elif self.search_type == "semantic_hybrid":
            docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs)
        elif self.search_type == "semantic_hybrid_score_threshold":
            docs = [
                doc
                for doc, _ in self.vectorstore.semantic_hybrid_search_with_score(
                    query, k=self.k, **kwargs
                )
            ]
        else:
            raise ValueError(f"search_type of {self.search_type} not allowed.")
        return docs
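
The retriever can also be constructed directly; search_type is checked against allowed_search_types by the root validator, so an unsupported value raises immediately. A sketch, reusing the vector_store from the earlier examples:

retriever = AzureSearchVectorStoreRetriever(
    vectorstore=vector_store,
    search_type="semantic_hybrid",
    k=4,
)
# AzureSearchVectorStoreRetriever(vectorstore=vector_store, search_type="mmr")
# would raise a ValueError, since "mmr" is not in allowed_search_types.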