# Source code for langchain_community.retrievers.google_vertex_ai_search

"""Retriever wrapper for Google Vertex AI Search."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_dict_or_env

from langchain_community.utilities.vertexai import get_client_info

if TYPE_CHECKING:
    from google.api_core.client_options import ClientOptions
    from google.cloud.discoveryengine_v1beta import (
        ConversationalSearchServiceClient,
        SearchRequest,
        SearchResult,
        SearchServiceClient,
    )


class _BaseGoogleVertexAISearchRetriever(BaseModel):
    """Shared configuration and response-conversion helpers for the
    Vertex AI Search retrievers."""

    project_id: str
    """Google Cloud Project ID."""
    data_store_id: Optional[str] = None
    """Vertex AI Search data store ID."""
    search_engine_id: Optional[str] = None
    """Vertex AI Search app ID."""
    location_id: str = "global"
    """Vertex AI Search data store location."""
    serving_config_id: str = "default_config"
    """Vertex AI Search serving config ID."""
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, credentials will be ascertained from
    the environment."""
    engine_data_type: int = Field(default=0, ge=0, le=3)
    """Defines the Vertex AI Search app data type
    0 - Unstructured data
    1 - Structured data
    2 - Website data
    3 - Blended search
    """

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Check required packages and resolve IDs from environment variables.

        ``project_id`` falls back to ``PROJECT_ID``; ``data_store_id`` and
        ``search_engine_id`` fall back to ``DATA_STORE_ID`` and
        ``SEARCH_ENGINE_ID`` respectively.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` or
                ``google-api-core`` is not installed.
        """
        try:
            from google.cloud import discoveryengine_v1beta  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install "
                "google-cloud-discoveryengine>=0.11.10"
            ) from exc
        try:
            from google.api_core.exceptions import InvalidArgument  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.api_core.exceptions is not installed. "
                "Please install it with pip install google-api-core"
            ) from exc

        values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")

        # Each optional ID is resolved on its own. A missing DATA_STORE_ID must
        # not stop SEARCH_ENGINE_ID from being read from the environment.
        # Neither ID is mandatory at this level, so a lookup miss (ValueError)
        # simply leaves the value unset.
        for field_name, env_key in (
            ("data_store_id", "DATA_STORE_ID"),
            ("search_engine_id", "SEARCH_ENGINE_ID"),
        ):
            try:
                values[field_name] = get_from_dict_or_env(values, field_name, env_key)
            except ValueError:
                pass

        return values

    @property
    def client_options(self) -> "ClientOptions":
        """Client options for the Discovery Engine API client.

        For any location other than ``global``, the endpoint is set to
        ``{location_id}-discoveryengine.googleapis.com``. For ``global``,
        ``api_endpoint`` is left as ``None``.
        """
        from google.api_core.client_options import ClientOptions

        return ClientOptions(
            api_endpoint=(
                f"{self.location_id}-discoveryengine.googleapis.com"
                if self.location_id != "global"
                else None
            )
        )

    def _convert_structured_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Convert structured-data search results into LangChain documents.

        Each result becomes one Document. Its ``page_content`` is the JSON dump
        of the document's ``struct_data`` (``{}`` when absent), and its
        metadata holds the document's ``id`` and ``name``.
        """
        import json

        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )

            documents.append(
                Document(
                    page_content=json.dumps(document_dict.get("struct_data", {})),
                    metadata={"id": document_dict["id"], "name": document_dict["name"]},
                )
            )

        return documents

    def _convert_unstructured_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Convert unstructured-data search results into LangChain documents.

        Args:
            results: Search results returned by the API.
            chunk_type: Key inside ``derived_struct_data`` whose chunks are
                emitted, e.g. ``"extractive_answers"`` or
                ``"extractive_segments"``.

        Returns:
            One Document per chunk. A result is skipped when it has no
            ``derived_struct_data`` or that data lacks ``chunk_type``. Each
            chunk's metadata is a copy of the document's ``struct_data`` plus
            ``id`` and ``source``. ``source`` is the document link; for
            extractive answers it also gets a ``:<pageNumber>`` suffix.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            # Skip before touching any other field of the result.
            if not derived_struct_data or chunk_type not in derived_struct_data:
                continue

            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]

            for chunk in derived_struct_data[chunk_type]:
                chunk_metadata = doc_metadata.copy()
                chunk_metadata["source"] = derived_struct_data.get("link", "")

                if chunk_type == "extractive_answers":
                    chunk_metadata["source"] += f":{chunk.get('pageNumber', '')}"

                documents.append(
                    Document(
                        page_content=chunk.get("content", ""), metadata=chunk_metadata
                    )
                )

        return documents

    def _convert_website_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Convert website-data search results into LangChain documents.

        Args:
            results: Search results returned by the API.
            chunk_type: Key inside ``derived_struct_data`` whose chunks are
                emitted. For ``"snippets"``, the text is read from each chunk's
                ``snippet`` field; otherwise it is read from ``content``.

        Returns:
            One Document per chunk. The metadata is the document's
            ``struct_data`` plus ``id`` and ``source`` (the document link).
            When nothing is found, a hint is printed to stdout.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []

        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data or chunk_type not in derived_struct_data:
                continue

            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]
            doc_metadata["source"] = derived_struct_data.get("link", "")

            text_field = "snippet" if chunk_type == "snippets" else "content"

            for chunk in derived_struct_data[chunk_type]:
                documents.append(
                    Document(
                        page_content=chunk.get(text_field, ""),
                        # Copy per chunk so that Documents from the same result
                        # do not share, and cross-mutate, one metadata dict.
                        metadata=doc_metadata.copy(),
                    )
                )

        if not documents:
            print(f"No {chunk_type} could be found.")  # noqa: T201
            if chunk_type == "extractive_answers":
                print(  # noqa: T201
                    "Make sure that your data store is using Advanced Website "
                    "Indexing.\n"
                    "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing"  # noqa: E501
                )

        return documents


@deprecated(
    since="0.0.33",
    removal="0.3.0",
    alternative_import="langchain_google_community.VertexAISearchRetriever",
)
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
    """`Google Vertex AI Search` retriever.

    For a detailed explanation of the Vertex AI Search concepts
    and configuration parameters, refer to the product documentation.

    https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
    """

    filter: Optional[str] = None
    """Filter expression."""
    get_extractive_answers: bool = False
    """If True return Extractive Answers, otherwise return Extractive Segments
    or Snippets."""
    max_documents: int = Field(default=5, ge=1, le=100)
    """The maximum number of documents to return."""
    max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
    """The maximum number of extractive answers returned in each search result.
    At most 5 answers will be returned for each SearchResult.
    """
    max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
    """The maximum number of extractive segments returned in each search result.
    Currently one segment will be returned for each SearchResult.
    """
    query_expansion_condition: int = Field(default=1, ge=0, le=2)
    """Specification to determine under which conditions query expansion should
    occur.
    0 - Unspecified query expansion condition. In this case, server behavior
        defaults to disabled.
    1 - Disabled query expansion. Only the exact search query is used, even if
        SearchResponse.total_size is zero.
    2 - Automatic query expansion built by the Search API.
    """
    spell_correction_mode: int = Field(default=2, ge=0, le=2)
    """Specification to determine under which conditions spell correction should
    occur.
    0 - Unspecified spell correction mode. In this case, server behavior
        defaults to auto.
    1 - Suggestion only. Search API will try to find a spell suggestion if there
        is any and put it in `SearchResponse.corrected_query`.
        The spell suggestion will not be used as the search query.
    2 - Automatic spell correction built by the Search API.
        Search will be based on the corrected query if found.
    """

    _client: SearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    def __init__(self, **kwargs: Any) -> None:
        """Initialize private fields (API client and serving config path).

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` is not installed.
            ValueError: If ``engine_data_type`` is 3 (blended) without a
                ``search_engine_id``, or if neither ``data_store_id`` nor
                ``search_engine_id`` is set.
        """
        try:
            from google.cloud.discoveryengine_v1beta import SearchServiceClient
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc

        super().__init__(**kwargs)

        # Validate the configuration before building the API client, so that a
        # misconfiguration is reported with these messages rather than
        # surfacing from client construction.
        if self.engine_data_type == 3 and not self.search_engine_id:
            raise ValueError(
                "search_engine_id must be specified for blended search apps."
            )
        if not self.search_engine_id and not self.data_store_id:
            raise ValueError(
                "Either data_store_id or search_engine_id must be specified."
            )

        # For more information, refer to:
        # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
        self._client = SearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        if self.search_engine_id:
            # Engine (app) serving config. Honour serving_config_id here too; it
            # defaults to "default_config", which was previously hard-coded.
            self._serving_config = (
                f"projects/{self.project_id}/locations/{self.location_id}"
                "/collections/default_collection"
                f"/engines/{self.search_engine_id}"
                f"/servingConfigs/{self.serving_config_id}"
            )
        else:
            self._serving_config = self._client.serving_config_path(
                project=self.project_id,
                location=self.location_id,
                data_store=self.data_store_id,
                serving_config=self.serving_config_id,
            )

    def _create_search_request(self, query: str) -> SearchRequest:
        """Build a ``SearchRequest`` for ``query`` from this retriever's settings.

        The content search spec depends on ``engine_data_type``:
        0 requests extractive answers or segments (see ``get_extractive_answers``);
        1 sends no content spec; 2 and 3 request extractive answers plus
        snippets.

        Raises:
            NotImplementedError: For an unsupported ``engine_data_type``.
        """
        from google.cloud.discoveryengine_v1beta import SearchRequest

        query_expansion_spec = SearchRequest.QueryExpansionSpec(
            condition=self.query_expansion_condition,
        )
        spell_correction_spec = SearchRequest.SpellCorrectionSpec(
            mode=self.spell_correction_mode
        )

        if self.engine_data_type == 0:
            if self.get_extractive_answers:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_answer_count=self.max_extractive_answer_count,
                    )
                )
            else:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_segment_count=self.max_extractive_segment_count,
                    )
                )
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=extractive_content_spec
            )
        elif self.engine_data_type == 1:
            content_search_spec = None
        elif self.engine_data_type in (2, 3):
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(  # noqa: E501
                    max_extractive_answer_count=self.max_extractive_answer_count,
                ),
                snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec(
                    return_snippet=True
                ),
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return SearchRequest(
            query=query,
            filter=self.filter,
            serving_config=self._serving_config,
            page_size=self.max_documents,
            content_search_spec=content_search_spec,
            query_expansion_spec=query_expansion_spec,
            spell_correction_spec=spell_correction_spec,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query."""
        return self.get_relevant_documents_with_response(query)[0]

    def get_relevant_documents_with_response(
        self, query: str
    ) -> Tuple[List[Document], Any]:
        """Run a search and return both the documents and the raw response.

        Args:
            query: The search query.

        Returns:
            A ``(documents, response)`` tuple, where ``response`` is the object
            returned by ``SearchServiceClient.search``.

        Raises:
            InvalidArgument: Re-raised (same exception type) with a hint about
                ``engine_data_type`` when the API rejects the request.
            NotImplementedError: For an unsupported ``engine_data_type``.
        """
        from google.api_core.exceptions import InvalidArgument

        search_request = self._create_search_request(query)

        try:
            response = self._client.search(search_request)
        except InvalidArgument as exc:
            # Keep the original error as the cause for debugging.
            raise type(exc)(
                exc.message
                + " This might be due to engine_data_type not set correctly."
            ) from exc

        if self.engine_data_type == 0:
            chunk_type = (
                "extractive_answers"
                if self.get_extractive_answers
                else "extractive_segments"
            )
            documents = self._convert_unstructured_search_response(
                response.results, chunk_type
            )
        elif self.engine_data_type == 1:
            documents = self._convert_structured_search_response(response.results)
        elif self.engine_data_type in (2, 3):
            chunk_type = (
                "extractive_answers" if self.get_extractive_answers else "snippets"
            )
            documents = self._convert_website_search_response(
                response.results, chunk_type
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return documents, response
@deprecated(
    since="0.0.33",
    removal="0.3.0",
    alternative_import="langchain_google_community.VertexAIMultiTurnSearchRetriever",
)
class GoogleVertexAIMultiTurnSearchRetriever(
    BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
    """`Google Vertex AI Search` retriever for multi-turn conversations."""

    conversation_id: str = "-"
    """Vertex AI Search Conversation ID."""

    _client: ConversationalSearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the conversational client and serving config path.

        Raises:
            ValueError: If ``data_store_id`` is not set.
            NotImplementedError: If ``engine_data_type`` is 1 (structured) or
                3 (blended).
        """
        super().__init__(**kwargs)

        # Validate the configuration before building the API client, so that a
        # misconfiguration is reported with these messages.
        if not self.data_store_id:
            raise ValueError("data_store_id is required for MultiTurnSearchRetriever.")

        if self.engine_data_type in (1, 3):
            raise NotImplementedError(
                "Data store type 1 (Structured) and 3 (Blended) "
                "is not currently supported for multi-turn search."
                + f" Got {self.engine_data_type}"
            )

        from google.cloud.discoveryengine_v1beta import (
            ConversationalSearchServiceClient,
        )

        self._client = ConversationalSearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        self._serving_config = self._client.serving_config_path(
            project=self.project_id,
            location=self.location_id,
            data_store=self.data_store_id,
            serving_config=self.serving_config_id,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Send ``query`` as the next turn of the conversation and return
        extractive answers as documents.

        Website data stores (``engine_data_type == 2``) use the website
        converter; other data stores use the unstructured converter.
        """
        from google.cloud.discoveryengine_v1beta import (
            ConverseConversationRequest,
            TextInput,
        )

        request = ConverseConversationRequest(
            name=self._client.conversation_path(
                self.project_id,
                self.location_id,
                self.data_store_id,
                self.conversation_id,
            ),
            serving_config=self._serving_config,
            query=TextInput(input=query),
        )
        response = self._client.converse_conversation(request)

        if self.engine_data_type == 2:
            return self._convert_website_search_response(
                response.search_results, "extractive_answers"
            )

        return self._convert_unstructured_search_response(
            response.search_results, "extractive_answers"
        )
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
    """`Google Vertex Search API` retriever alias for backwards compatibility.

    DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
    """

    def __init__(self, **data: Any) -> None:
        """Emit a DeprecationWarning, then initialize the parent retriever."""
        import warnings

        warnings.warn(
            "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the caller constructing this alias,
            # not to this __init__.
            stacklevel=2,
        )

        super().__init__(**data)