# Source code for langchain_community.chains.graph_qa.kuzu

"""在图上的问答。"""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    KUZU_GENERATION_PROMPT,
)
from langchain_community.graphs.kuzu_graph import KuzuGraph


def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a single leading ``prefix`` removed, if present.

    Behaves like ``str.removeprefix`` (Python 3.9+): when ``text`` does not
    start with ``prefix`` it is returned unchanged, and an empty ``prefix``
    is a no-op.

    Args:
        text: The string to strip.
        prefix: The prefix to remove from the start of ``text``.

    Returns:
        ``text`` minus one leading occurrence of ``prefix``, or ``text`` itself
        when it does not begin with ``prefix``.
    """
    # Guard clause: nothing to strip.
    if not text.startswith(prefix):
        return text
    return text[len(prefix) :]
def extract_cypher(text: str) -> str:
    """Pull a Cypher query out of an LLM response.

    If ``text`` contains a triple-backtick fenced block, the contents of the
    first such block are returned verbatim. This includes any language tag
    written right after the opening fence (e.g. ``cypher``) and any
    surrounding newlines. If there is no fenced block, ``text`` is returned
    unchanged.

    Args:
        text: Raw text, typically an LLM completion.

    Returns:
        The body of the first fenced block (possibly empty), or ``text`` when
        no fenced block is present.
    """
    # Non-greedy so only the first ``` ... ``` pair is captured; DOTALL lets
    # the fenced body span multiple lines.
    fenced = re.search(r"```(.*?)```", text, re.DOTALL)
    if fenced is None:
        return text
    return fenced.group(1)
class KuzuQAChain(Chain):
    """Question answering over a Kùzu graph by generating Cypher statements.

    The chain asks an LLM to translate the user's question into a Cypher
    statement using the graph schema. It then executes that statement against
    the graph and asks a second LLM call to answer the question from the
    returned results.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly scoped to include only the necessary permissions.
        Failure to do so may result in data corruption or loss, since the
        calling code may attempt commands that would result in deletion or
        mutation of data if appropriately prompted, or reading sensitive data
        if such data is present in the database.
        The best way to guard against such negative outcomes is to (as
        appropriate) limit the permissions granted to the credentials used
        with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: KuzuGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        **kwargs: Any,
    ) -> KuzuQAChain:
        """Initialize the chain from one or two language models.

        Args:
            llm: Shared fallback model, used for whichever of ``cypher_llm`` /
                ``qa_llm`` is not supplied.
            qa_prompt: Prompt for the answer-generation step.
            cypher_prompt: Prompt for the Cypher-generation step.
            cypher_llm: Model dedicated to generating Cypher.
            qa_llm: Model dedicated to answering from the query results.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            A new ``KuzuQAChain``.

        Raises:
            ValueError: If no model is available for the Cypher step or for
                the QA step, or if ``llm``, ``cypher_llm`` and ``qa_llm`` are
                all supplied at once.
        """
        # Resolve the model for each step: its dedicated model, else ``llm``.
        cypher_model = cypher_llm or llm
        answer_model = qa_llm or llm

        if not cypher_model:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not answer_model:
            raise ValueError(
                "Either `llm` or `qa_llm` parameters must be provided along with"
                " `cypher_llm`"
            )
        if llm and cypher_llm and qa_llm:
            # One of the three would be silently ignored, so reject the call.
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )

        answer_chain = LLMChain(llm=answer_model, prompt=qa_prompt)
        cypher_chain = LLMChain(llm=cypher_model, prompt=cypher_prompt)
        return cls(
            qa_chain=answer_chain,
            cypher_generation_chain=cypher_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate a Cypher statement, run it on the graph, and answer.

        Args:
            inputs: Chain inputs. The question is read from ``self.input_key``.
            run_manager: Optional callback manager. A no-op manager is used
                when none is given.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the value the
            QA chain produced under its own output key.
        """
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        question = inputs[self.input_key]

        # Step 1: have the LLM translate the question into Cypher.
        raw_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema},
            callbacks=child_callbacks,
        )
        # The model may wrap its output in a ```cypher ... ``` fence: keep only
        # the fenced body, then drop a leading "cypher" language tag if present.
        cypher_query = remove_prefix(extract_cypher(raw_cypher), "cypher")
        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher_query, color="green", end="\n", verbose=self.verbose)

        # Step 2: execute the statement against the graph.
        context = self.graph.query(cypher_query)
        manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        manager.on_text(str(context), color="green", end="\n", verbose=self.verbose)

        # Step 3: answer the original question from the query results.
        qa_output = self.qa_chain(
            {"question": question, "context": context},
            callbacks=child_callbacks,
        )
        return {self.output_key: qa_output[self.qa_chain.output_key]}