# Source code for langchain_community.chains.graph_qa.arangodb

"""在图上的问答。"""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.prompts import (
    AQL_FIX_PROMPT,
    AQL_GENERATION_PROMPT,
    AQL_QA_PROMPT,
)
from langchain_community.graphs.arangodb_graph import ArangoGraph


class ArangoGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating AQL statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: ArangoGraph = Field(exclude=True)
    aql_generation_chain: LLMChain
    aql_fix_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    # Specifies the maximum number of AQL Query Results to return
    top_k: int = 10

    # Specifies the set of AQL Query Examples that promote few-shot-learning
    aql_examples: str = ""

    # Specify whether to return the AQL Query in the output dictionary
    return_aql_query: bool = False

    # Specify whether to return the AQL JSON Result in the output dictionary
    return_aql_result: bool = False

    # Specify the maximum amount of AQL Generation attempts that should be made
    max_aql_generation_attempts: int = 3

    @property
    def input_keys(self) -> List[str]:
        """Return the single input key read by ``_call`` (``input_key``)."""
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the key under which ``_call`` stores the answer."""
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_aql_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
        aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
        aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
        **kwargs: Any,
    ) -> ArangoGraphQAChain:
        """Initialize from an LLM.

        The same ``llm`` backs all three sub-chains.

        Args:
            llm: Language model used for generation, fixing and answering.
            qa_prompt: Prompt for turning the AQL result into an answer.
            aql_generation_prompt: Prompt for generating the initial AQL query.
            aql_fix_prompt: Prompt for repairing a query that failed to execute.
            **kwargs: Remaining fields of the chain (e.g. ``graph``, ``top_k``).

        Returns:
            A configured ``ArangoGraphQAChain``.
        """
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
        aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)

        return cls(
            qa_chain=qa_chain,
            aql_generation_chain=aql_generation_chain,
            aql_fix_chain=aql_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an AQL statement from user input, use it to retrieve a
        response from an ArangoDB database instance, and respond to the user
        input in natural language.

        Users can modify the following ArangoGraphQAChain class variables:

        :var top_k: The maximum number of AQL Query Results to return
        :type top_k: int

        :var aql_examples: A set of AQL Query Examples that are passed to
            the AQL Generation Prompt Template to promote few-shot-learning.
            Defaults to an empty string.
        :type aql_examples: str

        :var return_aql_query: Whether to return the AQL Query in the
            output dictionary. Defaults to False.
        :type return_aql_query: bool

        :var return_aql_result: Whether to return the AQL JSON Result in the
            output dictionary. Defaults to False.
        :type return_aql_result: bool

        :var max_aql_generation_attempts: The maximum amount of AQL
            Generation attempts to be made prior to raising the last
            AQL Query Execution Error. Defaults to 3.
        :type max_aql_generation_attempts: int

        :returns: A dict mapping ``output_key`` to the QA chain's answer, plus
            ``"aql_query"`` / ``"aql_result"`` when the corresponding flags
            are set.
        :raises ValueError: If an LLM response contains no fenced code block,
            or if no generated query executed successfully within
            ``max_aql_generation_attempts`` attempts.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        user_input = inputs[self.input_key]

        # Generate the initial AQL query from the schema, examples and question.
        aql_generation_output = self.aql_generation_chain.run(
            {
                "adb_schema": self.graph.schema,
                "aql_examples": self.aql_examples,
                "user_input": user_input,
            },
            callbacks=callbacks,
        )

        # Imported lazily so the module loads without python-arango installed;
        # hoisted out of the loop since it only needs to happen once.
        from arango import AQLQueryExecuteError

        aql_query = ""
        aql_error = ""
        aql_result = None
        aql_generation_attempt = 1

        while (
            aql_result is None
            and aql_generation_attempt <= self.max_aql_generation_attempts
        ):
            # Extract the first fenced code block (optionally tagged "aql",
            # case-insensitively) from the LLM output.
            match = re.search(
                r"```(?i:aql)?(.*?)```", aql_generation_output, re.DOTALL
            )
            if match is None:
                _run_manager.on_text(
                    "Invalid Response: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_generation_output, color="red", end="\n", verbose=self.verbose
                )
                raise ValueError(f"Response is Invalid: {aql_generation_output}")

            aql_query = match.group(1)

            _run_manager.on_text(
                f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
            )
            _run_manager.on_text(
                aql_query, color="green", end="\n", verbose=self.verbose
            )

            # Execute the AQL query; on failure, ask the LLM to repair it.
            try:
                aql_result = self.graph.query(aql_query, self.top_k)
            except AQLQueryExecuteError as e:
                aql_error = e.error_message

                _run_manager.on_text(
                    "AQL Query Execution Error: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_error, color="yellow", end="\n\n", verbose=self.verbose
                )

                # Only request a fix when another attempt will use it; on the
                # final attempt the regenerated query would be discarded, so
                # calling the LLM there would be a wasted (billed) request.
                if aql_generation_attempt < self.max_aql_generation_attempts:
                    aql_generation_output = self.aql_fix_chain.run(
                        {
                            "adb_schema": self.graph.schema,
                            "aql_query": aql_query,
                            "aql_error": aql_error,
                        },
                        callbacks=callbacks,
                    )

            aql_generation_attempt += 1

        if aql_result is None:
            m = f"""
                Maximum amount of AQL Query Generation attempts reached.
                Unable to execute the AQL Query due to the following error:
                {aql_error}
            """
            raise ValueError(m)

        _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(aql_result), color="green", end="\n", verbose=self.verbose
        )

        # Interpret the AQL result in natural language.
        qa_output = self.qa_chain(
            {
                "adb_schema": self.graph.schema,
                "user_input": user_input,
                "aql_query": aql_query,
                "aql_result": aql_result,
            },
            callbacks=callbacks,
        )

        result: Dict[str, Any] = {self.output_key: qa_output[self.qa_chain.output_key]}

        if self.return_aql_query:
            result["aql_query"] = aql_query

        if self.return_aql_result:
            result["aql_result"] = aql_result

        return result