"""Retriever的Google Vertex AI Search包装器。"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.utils import get_from_dict_or_env
from langchain_community.utilities.vertexai import get_client_info
if TYPE_CHECKING:
from google.api_core.client_options import ClientOptions
from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
SearchRequest,
SearchResult,
SearchServiceClient,
)
class _BaseGoogleVertexAISearchRetriever(BaseModel):
    """Shared configuration and response-conversion helpers for Vertex AI Search.

    Holds the fields common to the single-turn and multi-turn retrievers,
    resolves project / data store / engine ids from the environment, and
    converts ``SearchResult`` protos into LangChain ``Document`` objects.
    """

    project_id: str
    """Google Cloud Project ID."""
    data_store_id: Optional[str] = None
    """Vertex AI Search data store ID."""
    search_engine_id: Optional[str] = None
    """Vertex AI Search app ID."""
    location_id: str = "global"
    """Vertex AI Search data store location."""
    serving_config_id: str = "default_config"
    """Vertex AI Search serving config ID."""
    credentials: Any = None
    """The default custom credentials (google.auth.credentials.Credentials) to use
    when making API calls. If not provided, credentials will be ascertained from
    the environment."""
    engine_data_type: int = Field(default=0, ge=0, le=3)
    """Defines the Vertex AI Search app data type
    0 - Unstructured data
    1 - Structured data
    2 - Website data
    3 - Blended search
    """

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment and resolve ids from environment variables.

        Checks that the Google client libraries are importable, then fills
        ``project_id`` (required) and ``data_store_id`` / ``search_engine_id``
        (each optional on its own) from ``values`` or from the ``PROJECT_ID``,
        ``DATA_STORE_ID`` and ``SEARCH_ENGINE_ID`` environment variables.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` or
                ``google-api-core`` is not installed.
            ValueError: If ``project_id`` is neither passed nor found in the
                environment.
        """
        try:
            from google.cloud import discoveryengine_v1beta  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install "
                "google-cloud-discoveryengine>=0.11.10"
            ) from exc
        try:
            from google.api_core.exceptions import InvalidArgument  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "google.api_core.exceptions is not installed. "
                "Please install it with pip install google-api-core"
            ) from exc

        values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")

        # Resolve each optional id independently: previously a missing
        # data_store_id aborted the shared try-block before SEARCH_ENGINE_ID
        # was ever read from the environment.
        for field_name, env_key in (
            ("data_store_id", "DATA_STORE_ID"),
            ("search_engine_id", "SEARCH_ENGINE_ID"),
        ):
            try:
                values[field_name] = get_from_dict_or_env(values, field_name, env_key)
            except ValueError:
                # Neither passed explicitly nor set in the environment:
                # keep the field default (None).
                pass
        return values

    @property
    def client_options(self) -> "ClientOptions":
        """Build ``ClientOptions`` for the configured location.

        Non-global locations use the regional
        ``{location_id}-discoveryengine.googleapis.com`` endpoint; for the
        ``"global"`` location ``api_endpoint`` is left as ``None``.
        """
        from google.api_core.client_options import ClientOptions

        return ClientOptions(
            api_endpoint=(
                f"{self.location_id}-discoveryengine.googleapis.com"
                if self.location_id != "global"
                else None
            )
        )

    def _convert_structured_search_response(
        self, results: Sequence[SearchResult]
    ) -> List[Document]:
        """Convert structured-data search results into LangChain documents.

        Args:
            results: Search results returned by the API.

        Returns:
            One document per result. ``page_content`` is the JSON-serialized
            ``struct_data`` (``{}`` when absent); metadata holds the document
            ``id`` and ``name``.
        """
        import json

        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []
        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            documents.append(
                Document(
                    page_content=json.dumps(document_dict.get("struct_data", {})),
                    metadata={"id": document_dict["id"], "name": document_dict["name"]},
                )
            )
        return documents

    def _convert_unstructured_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Convert unstructured-data search results into LangChain documents.

        Args:
            results: Search results returned by the API.
            chunk_type: Key inside ``derived_struct_data`` to read chunks from,
                e.g. ``"extractive_answers"`` or ``"extractive_segments"``.

        Returns:
            One document per chunk. Results lacking ``derived_struct_data`` or
            the requested ``chunk_type`` are skipped. Each document gets its own
            copy of the result's ``struct_data`` plus ``id`` and ``source`` (the
            result link, suffixed with ``:<pageNumber>`` for extractive answers).
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []
        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data:
                continue
            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]
            if chunk_type not in derived_struct_data:
                continue
            for chunk in derived_struct_data[chunk_type]:
                chunk_metadata = doc_metadata.copy()
                chunk_metadata["source"] = derived_struct_data.get("link", "")
                if chunk_type == "extractive_answers":
                    # Struct keys are not renamed by MessageToDict, hence camelCase.
                    chunk_metadata["source"] += f":{chunk.get('pageNumber', '')}"
                documents.append(
                    Document(
                        page_content=chunk.get("content", ""), metadata=chunk_metadata
                    )
                )
        return documents

    def _convert_website_search_response(
        self, results: Sequence[SearchResult], chunk_type: str
    ) -> List[Document]:
        """Convert website / blended search results into LangChain documents.

        Args:
            results: Search results returned by the API.
            chunk_type: ``"snippets"`` (text read from each chunk's ``snippet``
                key) or another ``derived_struct_data`` key such as
                ``"extractive_answers"`` (text read from ``content``).

        Returns:
            One document per chunk, each with its own copy of the result's
            ``struct_data`` plus ``id`` and ``source`` (the result link). When
            no document is produced a hint is printed to stdout.
        """
        from google.protobuf.json_format import MessageToDict

        documents: List[Document] = []
        for result in results:
            document_dict = MessageToDict(
                result.document._pb, preserving_proto_field_name=True
            )
            derived_struct_data = document_dict.get("derived_struct_data")
            if not derived_struct_data:
                continue
            doc_metadata = document_dict.get("struct_data", {})
            doc_metadata["id"] = document_dict["id"]
            doc_metadata["source"] = derived_struct_data.get("link", "")
            if chunk_type not in derived_struct_data:
                continue
            text_field = "snippet" if chunk_type == "snippets" else "content"
            for chunk in derived_struct_data[chunk_type]:
                documents.append(
                    Document(
                        page_content=chunk.get(text_field, ""),
                        # Copy so chunks of the same result don't share (and
                        # mutate) a single metadata dict.
                        metadata=doc_metadata.copy(),
                    )
                )
        if not documents:
            print(f"No {chunk_type} could be found.")  # noqa: T201
            if chunk_type == "extractive_answers":
                print(  # noqa: T201
                    "Make sure that your data store is using Advanced Website "
                    "Indexing.\n"
                    "https://cloud.google.com/generative-ai-app-builder/docs/about-advanced-features#advanced-website-indexing"  # noqa: E501
                )
        return documents
@deprecated(
    since="0.0.33",
    removal="0.3.0",
    alternative_import="langchain_google_community.VertexAISearchRetriever",
)
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
    """`Google Vertex AI Search` retriever.

    For a detailed explanation of the Vertex AI Search concepts
    and configuration parameters, refer to the product documentation.
    https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
    """

    filter: Optional[str] = None
    """Filter expression."""
    get_extractive_answers: bool = False
    """If True return Extractive Answers, otherwise return Extractive Segments or Snippets."""  # noqa: E501
    max_documents: int = Field(default=5, ge=1, le=100)
    """The maximum number of documents to return."""
    max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
    """The maximum number of extractive answers returned in each search result.
    At most 5 answers will be returned for each SearchResult.
    """
    max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
    """The maximum number of extractive segments returned in each search result.
    Currently one segment will be returned for each SearchResult.
    """
    query_expansion_condition: int = Field(default=1, ge=0, le=2)
    """Specification to determine under which conditions query expansion should occur.
    0 - Unspecified query expansion condition. In this case, server behavior
        defaults to disabled.
    1 - Disabled query expansion. Only the exact search query is used, even if
        SearchResponse.total_size is zero.
    2 - Automatic query expansion built by the Search API.
    """
    spell_correction_mode: int = Field(default=2, ge=0, le=2)
    """Specification for the spell correction mode.
    0 - Unspecified spell correction mode. In this case, server behavior
        defaults to auto.
    1 - Suggestion only. Search API will try to find a spell suggestion if there
        is any and put in the `SearchResponse.corrected_query`.
        The spell suggestion will not be used as the search query.
    2 - Automatic spell correction built by the Search API.
        Search will be based on the corrected query if found.
    """

    _client: SearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the search client and resolve the serving config path.

        Raises:
            ImportError: If ``google-cloud-discoveryengine`` is not installed.
            ValueError: If a blended app (``engine_data_type == 3``) has no
                ``search_engine_id``, or if neither ``data_store_id`` nor
                ``search_engine_id`` is set.
        """
        try:
            from google.cloud.discoveryengine_v1beta import SearchServiceClient
        except ImportError as exc:
            raise ImportError(
                "google.cloud.discoveryengine is not installed. "
                "Please install it with pip install google-cloud-discoveryengine"
            ) from exc

        super().__init__(**kwargs)

        # Validate the configuration before constructing the client so that
        # misconfiguration is reported without building a client first.
        if self.engine_data_type == 3 and not self.search_engine_id:
            raise ValueError(
                "search_engine_id must be specified for blended search apps."
            )
        if not self.search_engine_id and not self.data_store_id:
            raise ValueError(
                "Either data_store_id or search_engine_id must be specified."
            )

        # For more information, refer to:
        # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
        self._client = SearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        if self.search_engine_id:
            # Honour serving_config_id here too; it was previously hard-coded
            # to "default_config" (still the field default).
            self._serving_config = (
                f"projects/{self.project_id}/locations/{self.location_id}"
                "/collections/default_collection"
                f"/engines/{self.search_engine_id}"
                f"/servingConfigs/{self.serving_config_id}"
            )
        else:
            self._serving_config = self._client.serving_config_path(
                project=self.project_id,
                location=self.location_id,
                data_store=self.data_store_id,
                serving_config=self.serving_config_id,
            )

    def _create_search_request(self, query: str) -> SearchRequest:
        """Build the ``SearchRequest`` for ``query`` from the retriever settings.

        The content search spec depends on ``engine_data_type``: extractive
        answers or segments for unstructured data (0), none for structured data
        (1), extractive answers plus snippets for website / blended apps (2, 3).

        Raises:
            NotImplementedError: For an unsupported ``engine_data_type``.
        """
        from google.cloud.discoveryengine_v1beta import SearchRequest

        query_expansion_spec = SearchRequest.QueryExpansionSpec(
            condition=self.query_expansion_condition,
        )
        spell_correction_spec = SearchRequest.SpellCorrectionSpec(
            mode=self.spell_correction_mode
        )

        if self.engine_data_type == 0:
            if self.get_extractive_answers:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_answer_count=self.max_extractive_answer_count,
                    )
                )
            else:
                extractive_content_spec = (
                    SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                        max_extractive_segment_count=self.max_extractive_segment_count,
                    )
                )
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=extractive_content_spec
            )
        elif self.engine_data_type == 1:
            content_search_spec = None
        elif self.engine_data_type in (2, 3):
            content_search_spec = SearchRequest.ContentSearchSpec(
                extractive_content_spec=SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
                    max_extractive_answer_count=self.max_extractive_answer_count,
                ),
                snippet_spec=SearchRequest.ContentSearchSpec.SnippetSpec(
                    return_snippet=True
                ),
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return SearchRequest(
            query=query,
            filter=self.filter,
            serving_config=self._serving_config,
            page_size=self.max_documents,
            content_search_spec=content_search_spec,
            query_expansion_spec=query_expansion_spec,
            spell_correction_spec=spell_correction_spec,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query."""
        return self.get_relevant_documents_with_response(query)[0]

    def get_relevant_documents_with_response(
        self, query: str
    ) -> Tuple[List[Document], Any]:
        """Run a search and return both the documents and the raw API response.

        Args:
            query: The search query.

        Returns:
            A ``(documents, response)`` tuple, where ``response`` is the object
            returned by ``SearchServiceClient.search``.

        Raises:
            InvalidArgument: Re-raised with a hint about ``engine_data_type``.
            NotImplementedError: For an unsupported ``engine_data_type``.
        """
        from google.api_core.exceptions import InvalidArgument

        search_request = self._create_search_request(query)

        try:
            response = self._client.search(search_request)
        except InvalidArgument as exc:
            raise type(exc)(
                exc.message
                + " This might be due to engine_data_type not set correctly."
            ) from exc

        if self.engine_data_type == 0:
            chunk_type = (
                "extractive_answers"
                if self.get_extractive_answers
                else "extractive_segments"
            )
            documents = self._convert_unstructured_search_response(
                response.results, chunk_type
            )
        elif self.engine_data_type == 1:
            documents = self._convert_structured_search_response(response.results)
        elif self.engine_data_type in (2, 3):
            chunk_type = (
                "extractive_answers" if self.get_extractive_answers else "snippets"
            )
            documents = self._convert_website_search_response(
                response.results, chunk_type
            )
        else:
            raise NotImplementedError(
                "Only data store type 0 (Unstructured), 1 (Structured), "
                "2 (Website), or 3 (Blended) are supported currently."
                + f" Got {self.engine_data_type}"
            )

        return documents, response
@deprecated(
    since="0.0.33",
    removal="0.3.0",
    alternative_import="langchain_google_community.VertexAIMultiTurnSearchRetriever",
)
class GoogleVertexAIMultiTurnSearchRetriever(
    BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
    """`Google Vertex AI Search` retriever for multi-turn conversations."""

    conversation_id: str = "-"
    """Vertex AI Search Conversation ID."""

    _client: ConversationalSearchServiceClient
    _serving_config: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    def __init__(self, **kwargs: Any):
        """Initialize the conversational search client and serving config path.

        Raises:
            ValueError: If ``data_store_id`` is not set.
            NotImplementedError: If ``engine_data_type`` is 1 (Structured) or
                3 (Blended), which multi-turn search does not support.
        """
        super().__init__(**kwargs)

        # Validate the configuration before constructing the client so that
        # misconfiguration is reported without building a client first.
        if not self.data_store_id:
            raise ValueError("data_store_id is required for MultiTurnSearchRetriever.")
        if self.engine_data_type in (1, 3):
            raise NotImplementedError(
                "Data store type 1 (Structured) and 3 (Blended) "
                "is not currently supported for multi-turn search."
                + f" Got {self.engine_data_type}"
            )

        from google.cloud.discoveryengine_v1beta import (
            ConversationalSearchServiceClient,
        )

        self._client = ConversationalSearchServiceClient(
            credentials=self.credentials,
            client_options=self.client_options,
            client_info=get_client_info(module="vertex-ai-search"),
        )

        self._serving_config = self._client.serving_config_path(
            project=self.project_id,
            location=self.location_id,
            data_store=self.data_store_id,
            serving_config=self.serving_config_id,
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Send ``query`` as a turn of ``conversation_id`` and return documents.

        The conversation's search results are converted as extractive answers:
        with the website converter for ``engine_data_type == 2``, otherwise
        with the unstructured converter.
        """
        from google.cloud.discoveryengine_v1beta import (
            ConverseConversationRequest,
            TextInput,
        )

        request = ConverseConversationRequest(
            name=self._client.conversation_path(
                self.project_id,
                self.location_id,
                self.data_store_id,
                self.conversation_id,
            ),
            serving_config=self._serving_config,
            query=TextInput(input=query),
        )
        response = self._client.converse_conversation(request)

        if self.engine_data_type == 2:
            return self._convert_website_search_response(
                response.search_results, "extractive_answers"
            )

        return self._convert_unstructured_search_response(
            response.search_results, "extractive_answers"
        )
class GoogleCloudEnterpriseSearchRetriever(GoogleVertexAISearchRetriever):
    """`Google Vertex Search API` retriever alias for backwards compatibility.

    DEPRECATED: Use `GoogleVertexAISearchRetriever` instead.
    """

    def __init__(self, **data: Any):
        """Emit a ``DeprecationWarning`` and initialize the parent retriever."""
        import warnings

        warnings.warn(
            "GoogleCloudEnterpriseSearchRetriever is deprecated, use GoogleVertexAISearchRetriever",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the caller's construction site, not here.
            stacklevel=2,
        )
        super().__init__(**data)