"""与Elasticsearch数据库交互的链。"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain.chains.base import Chain
from langchain.chains.elasticsearch_database.prompts import ANSWER_PROMPT, DSL_PROMPT
from langchain.chains.llm import LLMChain
if TYPE_CHECKING:
    # Needed only for type annotations; this module does not import
    # elasticsearch at runtime.
    from elasticsearch import Elasticsearch

# Output key under which the chain returns its list of intermediate steps.
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
class ElasticsearchDatabaseChain(Chain):
    """Chain for interacting with an Elasticsearch database.

    The chain asks one LLM chain to write an Elasticsearch query for the
    user's question, runs that query against the database, and asks a second
    LLM chain to answer the question from the search result.

    Example:
        .. code-block:: python

            from langchain.chains import ElasticsearchDatabaseChain
            from langchain_community.llms import OpenAI
            from elasticsearch import Elasticsearch

            database = Elasticsearch("http://localhost:9200")
            db_chain = ElasticsearchDatabaseChain.from_llm(OpenAI(), database)
    """

    query_chain: LLMChain
    """Chain used to create the ES query."""
    answer_chain: LLMChain
    """Chain used to answer the user's question."""
    database: Any
    """Elasticsearch database to connect to, of type elasticsearch.Elasticsearch."""
    top_k: int = 10
    """Number of results to return from the query; given to the query prompt."""
    ignore_indices: Optional[List[str]] = None
    """Index names to leave out. Cannot be combined with ``include_indices``."""
    include_indices: Optional[List[str]] = None
    """Index names to restrict to. Cannot be combined with ``ignore_indices``."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    sample_documents_in_index_info: int = 3
    """Number of sample documents appended to each index's mapping info."""
    return_intermediate_steps: bool = False
    """Whether to return the intermediate steps along with the final answer."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator()
    def validate_indices(cls, values: dict) -> dict:
        """Reject configurations that set both index filters.

        Args:
            values: Field values collected by pydantic.

        Returns:
            The unchanged ``values`` dict.

        Raises:
            ValueError: If both ``include_indices`` and ``ignore_indices``
                are non-empty.
        """
        # ``.get`` rather than indexing: a field that failed its own
        # validation is absent from ``values``, and a KeyError here would
        # hide the real validation error.
        if values.get("include_indices") and values.get("ignore_indices"):
            raise ValueError(
                "Cannot specify both 'include_indices' and 'ignore_indices'."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output key, plus the intermediate-steps key if enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        return [self.output_key, INTERMEDIATE_STEPS_KEY]

    def _list_indices(self) -> List[str]:
        """List the index names in the database, after applying the filters.

        Returns:
            Names reported by ``cat.indices``, restricted to
            ``include_indices`` and with ``ignore_indices`` removed when
            those are set.
        """
        all_indices = [
            index["index"] for index in self.database.cat.indices(format="json")
        ]
        if self.include_indices:
            all_indices = [i for i in all_indices if i in self.include_indices]
        if self.ignore_indices:
            all_indices = [i for i in all_indices if i not in self.ignore_indices]
        return all_indices

    def _get_indices_infos(self, indices: List[str]) -> str:
        """Describe the given indices for use in the query prompt.

        Args:
            indices: Names of the indices to describe.

        Returns:
            One ``Mapping for index <name>:`` section per index, separated by
            blank lines. When ``sample_documents_in_index_info`` is positive,
            each section also carries that many sample documents inside a
            ``/* ... */`` block.
        """
        # NOTE(review): an empty ``indices`` list joins to "" -- confirm how
        # the client treats an empty index name.
        mappings = self.database.indices.get_mapping(index=",".join(indices))
        with_samples = self.sample_documents_in_index_info > 0

        sections = []
        for index, entry in mappings.items():
            if with_samples:
                hits = self.database.search(
                    index=index,
                    query={"match_all": {}},
                    size=self.sample_documents_in_index_info,
                )["hits"]["hits"]
                samples = "\n".join(str(hit["_source"]) for hit in hits)
                # Build the text locally instead of writing it back into the
                # client's response object.
                info = str(entry) + "\n\n/*\n" + samples + "\n*/"
            else:
                info = entry["mappings"]
            sections.append(f"Mapping for index {index}:\n{info}")
        return "\n\n".join(sections)

    def _search(self, indices: List[str], query: str) -> str:
        """Run ``query`` against the given indices.

        Args:
            indices: Names of the indices to search.
            query: Request body passed to the client's ``search`` call.

        Returns:
            The search response converted with ``str``.
        """
        result = self.database.search(index=",".join(indices), body=query)
        return str(result)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an ES query, execute it, and answer the question.

        Args:
            inputs: Chain inputs; the question is read from ``input_key``.
            run_manager: Callback manager for this run. A no-op manager is
                used when omitted.

        Returns:
            A dict with the final answer under ``output_key`` and, when
            ``return_intermediate_steps`` is set, the recorded steps under
            ``INTERMEDIATE_STEPS_KEY``.

        Raises:
            Exception: Any error raised while generating the query, searching
                or answering is re-raised with an ``intermediate_steps``
                attribute holding the steps recorded so far.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        input_text = f"{inputs[self.input_key]}\nESQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        indices = self._list_indices()
        indices_info = self._get_indices_infos(indices)
        query_inputs: dict = {
            "input": input_text,
            "top_k": str(self.top_k),
            "indices_info": indices_info,
            "stop": ["\nESResult:"],
        }
        intermediate_steps: List = []
        try:
            intermediate_steps.append(query_inputs)  # input: es generation
            es_cmd = self.query_chain.run(
                callbacks=_run_manager.get_child(),
                **query_inputs,
            )

            _run_manager.on_text(es_cmd, color="green", verbose=self.verbose)
            intermediate_steps.append(
                es_cmd
            )  # output: elasticsearch dsl generation (no checker)
            intermediate_steps.append({"es_cmd": es_cmd})  # input: ES search
            result = self._search(indices=indices, query=es_cmd)
            intermediate_steps.append(str(result))  # output: ES search

            _run_manager.on_text("\nESResult: ", verbose=self.verbose)
            _run_manager.on_text(result, color="yellow", verbose=self.verbose)

            _run_manager.on_text("\nAnswer:", verbose=self.verbose)
            answer_inputs: dict = {"data": result, "input": input_text}
            intermediate_steps.append(answer_inputs)  # input: final answer
            final_result = self.answer_chain.run(
                callbacks=_run_manager.get_child(),
                **answer_inputs,
            )

            intermediate_steps.append(final_result)  # output: final answer
            _run_manager.on_text(final_result, color="green", verbose=self.verbose)
            chain_result: Dict[str, Any] = {self.output_key: final_result}
            if self.return_intermediate_steps:
                chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
            return chain_result
        except Exception as exc:
            # Append intermediate steps to exception, to aid in logging and later
            # improvement of few shot prompt seeds
            exc.intermediate_steps = intermediate_steps  # type: ignore
            raise

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "elasticsearch_database_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        database: Elasticsearch,
        *,
        query_prompt: Optional[BasePromptTemplate] = None,
        answer_prompt: Optional[BasePromptTemplate] = None,
        query_output_parser: Optional[BaseLLMOutputParser] = None,
        **kwargs: Any,
    ) -> ElasticsearchDatabaseChain:
        """Convenience method to construct ElasticsearchDatabaseChain from an LLM.

        Args:
            llm: The language model to use.
            database: The Elasticsearch database.
            query_prompt: The prompt to use for query construction.
                Defaults to ``DSL_PROMPT``.
            answer_prompt: The prompt to use for answering the user's
                question given the data. Defaults to ``ANSWER_PROMPT``.
            query_output_parser: The output parser to use for parsing the
                model-generated ES query. Defaults to SimpleJsonOutputParser.
            **kwargs: Additional arguments to pass to the constructor.

        Returns:
            A chain whose query and answer chains both use ``llm``.
        """
        query_prompt = query_prompt or DSL_PROMPT
        query_output_parser = query_output_parser or SimpleJsonOutputParser()
        query_chain = LLMChain(
            llm=llm, prompt=query_prompt, output_parser=query_output_parser
        )
        answer_prompt = answer_prompt or ANSWER_PROMPT
        answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
        return cls(
            query_chain=query_chain,
            answer_chain=answer_chain,
            database=database,
            **kwargs,
        )