Source code for langchain_community.chains.graph_qa.neptune_cypher

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)
from langchain_community.graphs import BaseNeptuneGraph

# Output key under which the chain returns its intermediate steps (the
# generated query and, unless return_direct is set, the query context) when
# return_intermediate_steps is enabled.
INTERMEDIATE_STEPS_KEY: str = "intermediate_steps"


def trim_query(query: str) -> str:
    """Trim a query so that it only contains lines starting with Cypher keywords.

    LLM output may interleave the generated query with explanatory prose.
    A line is kept when, after stripping surrounding whitespace and
    upper-casing, it begins with a whole-word Cypher clause keyword (e.g.
    ``MATCH``, ``RETURN``, ``UNION``) or with a ``//`` comment marker. Kept
    lines retain their original text (including indentation) and are each
    followed by a newline.

    Args:
        query: Text that presumably contains an openCypher query.

    Returns:
        The kept lines concatenated, or an empty string if no line matches.
    """
    keywords = (
        "CALL",
        "CREATE",
        "DELETE",
        "DETACH",
        "LIMIT",
        "MATCH",
        "MERGE",
        "OPTIONAL",
        "ORDER",
        "REMOVE",
        "RETURN",
        "SET",
        "SKIP",
        # UNION joins two sub-queries; dropping it would merge them into an
        # invalid single query.
        "UNION",
        "UNWIND",
        "WITH",
        "WHERE",
    )
    # `\b` requires the keyword to be a whole word, so prose such as
    # "Within ..." or "Returning ..." is not mistaken for a WITH / RETURN
    # clause, while "MATCH(n)" (no space before the parenthesis) still counts.
    # Comment lines ("//") are kept regardless of what follows.
    line_start = re.compile(r"(?:" + "|".join(keywords) + r")\b|//")
    return "".join(
        line + "\n"
        for line in query.split("\n")
        if line_start.match(line.strip().upper())
    )
def extract_cypher(text: str) -> str:
    """Extract Cypher code from text using a regular expression.

    Returns the contents of the first span enclosed in triple backticks
    (anything right after the opening fence, such as a ``cypher`` language
    tag, is included). If *text* contains no such span, it is returned
    unchanged.
    """
    # Non-greedy so only the first fenced span is taken; DOTALL lets the
    # span cover multiple lines.
    first_fence = re.search(r"```(.*?)```", text, re.DOTALL)
    if first_fence is None:
        return text
    return first_fence.group(1)
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
    """Decide whether to use the simple prompt.

    The simple prompt is chosen for Anthropic models. These are detected
    either from the model's ``_llm_type`` or, for Bedrock-hosted models, from
    its ``model_id``.

    Args:
        llm: The language model the generation prompt will be used with.

    Returns:
        True if an "anthropic" marker is found, otherwise False.
    """
    # getattr + isinstance: not every model object exposes these attributes,
    # and some wrappers define model_id but leave it as None; a bare `in`
    # check would raise AttributeError / TypeError in those cases.
    llm_type = getattr(llm, "_llm_type", None)
    if isinstance(llm_type, str) and "anthropic" in llm_type:
        return True
    # Bedrock anthropic: the provider only appears in the model id.
    model_id = getattr(llm, "model_id", None)
    if isinstance(model_id, str) and "anthropic" in model_id:
        return True
    return False
# Selects the openCypher generation prompt: the simplified variant when
# use_simple_prompt(llm) is true, otherwise the default prompt.
PROMPT_SELECTOR = ConditionalPromptSelector(
    conditionals=[
        (use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT),
    ],
    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
)
class NeptuneOpenCypherQAChain(Chain):
    """Chain for question-answering against a Neptune graph by generating
    openCypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

            chain = NeptuneOpenCypherQAChain.from_llm(
                llm=llm, graph=graph
            )
            response = chain.run(query)
    """

    # Graph client: supplies the schema and executes the generated query.
    # Excluded from serialization.
    graph: BaseNeptuneGraph = Field(exclude=True)
    # Chain that turns the question (plus schema) into an openCypher query.
    cypher_generation_chain: LLMChain
    # Chain that phrases the final answer from the query results.
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # NOTE(review): not referenced anywhere in this class.
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions to pass to the query generation prompt."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        extra_instructions: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneOpenCypherQAChain:
        """Initialize from an LLM.

        Args:
            llm: Model used for both query generation and answer phrasing.
            qa_prompt: Prompt for phrasing the answer from the query context.
            cypher_prompt: Prompt for query generation. If falsy, one is
                chosen via PROMPT_SELECTOR based on ``llm``.
            extra_instructions: Extra text passed to the generation prompt.
            **kwargs: Remaining fields forwarded to the constructor
                (e.g. ``graph``).
        """
        # Fall back to the model-dependent prompt when none is provided.
        generation_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            cypher_generation_chain=LLMChain(llm=llm, prompt=generation_prompt),
            extra_instructions=extra_instructions,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate an openCypher statement, run it against the graph, and
        answer the question from the results."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        question = inputs[self.input_key]
        steps: List = []

        raw_cypher = self.cypher_generation_chain.run(
            {
                "question": question,
                "schema": self.graph.get_schema,
                "extra_instructions": self.extra_instructions or "",
            },
            callbacks=child_callbacks,
        )
        # The LLM may fence the query in backticks and/or surround it with
        # prose: take the fenced code, then keep only Cypher-keyword lines.
        cypher = trim_query(extract_cypher(raw_cypher))

        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher, color="green", end="\n", verbose=self.verbose)
        steps.append({"query": cypher})

        context = self.graph.query(cypher)

        if self.return_direct:
            # Skip the QA step and hand back the raw query result.
            answer = context
        else:
            manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )
            steps.append({"context": context})
            qa_output = self.qa_chain(
                {"question": question, "context": context},
                callbacks=child_callbacks,
            )
            answer = qa_output[self.qa_chain.output_key]

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_intermediate_steps:
            output[INTERMEDIATE_STEPS_KEY] = steps
        return output