Source code for langchain.chains.elasticsearch_database.base

"""与Elasticsearch数据库交互的链。"""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator

from langchain.chains.base import Chain
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT
from langchain.chains.llm import LLMChain

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch

# Output key under which the chain's recorded intermediate steps are returned
# (listed by ``output_keys``) when ``return_intermediate_steps`` is enabled.
INTERMEDIATE_STEPS_KEY: str = "intermediate_steps"


class ElasticsearchDatabaseChain(Chain):
    """Chain for interacting with an Elasticsearch database.

    ``query_chain`` turns the user's question into an Elasticsearch query, the
    query is run against the allowed indices, and ``answer_chain`` answers the
    question from the search result.

    Example:
        .. code-block:: python

            from langchain.chains import ElasticsearchDatabaseChain
            from langchain_community.llms import OpenAI
            from elasticsearch import Elasticsearch

            database = Elasticsearch("http://localhost:9200")
            db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
    """

    query_chain: LLMChain
    """Chain for creating the ES query."""
    answer_chain: LLMChain
    """Chain for answering the user question."""
    database: Any
    """Elasticsearch database to connect to, of type elasticsearch.Elasticsearch."""
    top_k: int = 10
    """Number of results the query prompt asks for (passed to the prompt as
    ``top_k``; the search call itself does not enforce it)."""
    #: Index names to exclude; mutually exclusive with ``include_indices``.
    ignore_indices: Optional[List[str]] = None
    #: Index names to restrict the chain to; mutually exclusive with
    #: ``ignore_indices``.
    include_indices: Optional[List[str]] = None
    input_key: str = "question"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    #: Number of sample documents appended to each index mapping shown to the
    #: model; ``0`` disables sampling.
    sample_documents_in_index_info: int = 3
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_indices(cls, values: dict) -> dict:
        """Reject configurations that set both include and ignore index lists."""
        # Use ``.get``: pydantic v1 leaves fields that failed their own
        # validation out of ``values``, and indexing would raise a bare
        # KeyError instead of surfacing the original validation error.
        if values.get("include_indices") and values.get("ignore_indices"):
            raise ValueError(
                "Cannot specify both 'include_indices' and 'ignore_indices'."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _list_indices(self) -> List[str]:
        """Return the index names the chain may query.

        Starts from every index reported by ``cat.indices`` and applies
        ``include_indices`` / ``ignore_indices`` when they are set.
        """
        all_indices = [
            index["index"] for index in self.database.cat.indices(format="json")
        ]
        if self.include_indices:
            included = set(self.include_indices)
            all_indices = [i for i in all_indices if i in included]
        if self.ignore_indices:
            ignored = set(self.ignore_indices)
            all_indices = [i for i in all_indices if i not in ignored]
        return all_indices

    def _join_indices(self, indices: List[str]) -> str:
        """Join index names into the comma-separated ``index`` request argument.

        Raises:
            ValueError: If ``indices`` is empty. The Elasticsearch client omits
                an empty ``index`` argument from the request path, so the
                request would target *every* index and silently bypass
                ``include_indices`` / ``ignore_indices``.
        """
        if not indices:
            raise ValueError(
                "No Elasticsearch indices available to query (after applying "
                "'include_indices' / 'ignore_indices')."
            )
        return ",".join(indices)

    def _get_indices_infos(self, indices: List[str]) -> str:
        """Describe the mappings of ``indices`` for the query prompt.

        Each section reads ``Mapping for index <name>:`` followed by that
        index's ``mappings`` object. When ``sample_documents_in_index_info`` is
        positive, up to that many document ``_source`` values are appended
        inside a ``/* ... */`` comment. Sections are separated by blank lines.

        Raises:
            ValueError: If ``indices`` is empty.
        """
        mappings = self.database.indices.get_mapping(
            index=self._join_indices(indices)
        )
        sections = []
        for index in mappings:
            # Always render the inner mapping, with or without samples, so the
            # prompt format does not depend on the sampling setting.
            info = str(mappings[index]["mappings"])
            if self.sample_documents_in_index_info > 0:
                hits = self.database.search(
                    index=index,
                    query={"match_all": {}},
                    size=self.sample_documents_in_index_info,
                )["hits"]["hits"]
                samples = "\n".join(str(hit["_source"]) for hit in hits)
                info = info + "\n\n/*\n" + samples + "\n*/"
            sections.append("Mapping for index {}:\n{}".format(index, info))
        return "\n\n".join(sections)

    def _search(self, indices: List[str], query: Any) -> str:
        """Run ``query`` against ``indices`` and return the raw response as text.

        ``query`` is whatever ``query_chain`` produced. With the default
        ``SimpleJsonOutputParser`` it is a parsed JSON object.

        Raises:
            ValueError: If ``indices`` is empty.
        """
        result = self.database.search(index=self._join_indices(indices), body=query)
        return str(result)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an ES query for the question, run it, and answer from the result.

        Args:
            inputs: Must contain the user question under ``self.input_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the answer under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the recorded steps under
            ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            ValueError: If no indices remain after filtering.
            Exception: Errors from query generation, search, or answering are
                re-raised with an ``intermediate_steps`` attribute attached.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nESQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        indices = self._list_indices()
        indices_info = self._get_indices_infos(indices)
        query_inputs: dict = {
            "input": input_text,
            "top_k": str(self.top_k),
            "indices_info": indices_info,
            "stop": ["\nESResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(query_inputs)  # input: es generation
            es_cmd = self.query_chain.run(
                callbacks=_run_manager.get_child(),
                **query_inputs,
            )

            # ``on_text`` is typed to take a str, while the default output
            # parser yields a parsed JSON object.
            _run_manager.on_text(str(es_cmd), color="green", verbose=self.verbose)
            intermediate_steps.append(
                es_cmd
            )  # output: elasticsearch dsl generation (no checker)
            intermediate_steps.append({"es_cmd": es_cmd})  # input: ES search
            result = self._search(indices=indices, query=es_cmd)
            intermediate_steps.append(str(result))  # output: ES search

            _run_manager.on_text("\nESResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)

            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
            answer_inputs: dict = {"data": result, "input": input_text}
            intermediate_steps.append(answer_inputs)  # input: final answer
            final_result = self.answer_chain.run(
                callbacks=_run_manager.get_child(),
                **answer_inputs,
            )

            intermediate_steps.append(final_result)  # output: final answer
            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        return "elasticsearch_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: Elasticsearch,
        *,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        query_output_parser: Optional[BaseLLMOutputParser] = None,
        **kwargs: Any,
    ) -> ElasticsearchDatabaseChain:
        """Convenience method to construct ElasticsearchDatabaseChain from an LLM.

        Args:
            llm: The language model to use.
            database: The Elasticsearch db.
            query_prompt: The prompt to use for query construction. Defaults to
                ``DSL_PROMPT``.
            answer_prompt: The prompt to use for answering the user question
                given data. Defaults to ``ANSWER_PROMPT``.
            query_output_parser: The output parser to use for parsing the
                model's ES query output. Defaults to ``SimpleJsonOutputParser``.
            **kwargs: Additional arguments to pass to the constructor.
        """
        query_prompt = query_prompt or DSL_PROMPT
        query_output_parser = query_output_parser or SimpleJsonOutputParser()
        query_chain = LLMChain(
            llm=llm, prompt=query_prompt, output_parser=query_output_parser
        )
        answer_prompt = answer_prompt or ANSWER_PROMPT
        answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
        return cls(
            query_chain=query_chain,
            answer_chain=answer_chain,
            database=database,
            **kwargs,
        )