# Source code for langchain_community.chains.graph_qa.kuzu
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_community.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
KUZU_GENERATION_PROMPT,
)
from langchain_community.graphs.kuzu_graph import KuzuGraph
def remove_prefix(text: str, prefix: str) -> str:
    """Remove a prefix from a text.

    Args:
        text: Text to remove the prefix from.
        prefix: Prefix to remove from the text.

    Returns:
        ``text`` without ``prefix`` if it started with it, otherwise ``text``
        unchanged.
    """
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Returns the content between the first pair of triple backticks. A
    language tag written right after the opening backticks (e.g. ``cypher``)
    is kept in the result; callers strip it separately with
    ``remove_prefix``.

    Args:
        text: Text (typically LLM output) that may contain a fenced code block.

    Returns:
        The content of the first fenced block, or ``text`` unchanged when no
        fenced block is present.
    """
    # Non-greedy match so only the first fenced block is captured; DOTALL so
    # the captured code may span multiple lines.
    pattern = r"```(.*?)```"
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[0] if matches else text
class KuzuQAChain(Chain):
    """Question-answering against a graph by generating Cypher statements for Kùzu.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    # The graph is excluded from the model's serialized form (pydantic
    # ``exclude=True``).
    graph: KuzuGraph = Field(exclude=True)
    # Turns the question plus the graph schema into a Cypher statement.
    cypher_generation_chain: LLMChain
    # Answers the question from the rows returned by the Cypher query.
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        **kwargs: Any,
    ) -> KuzuQAChain:
        """Initialize from an LLM.

        Args:
            llm: Default model, used for whichever of the two chains is not
                given a dedicated model below.
            qa_prompt: Prompt for the answer-generation chain.
            cypher_prompt: Prompt for the Cypher-generation chain.
            cypher_llm: Model for Cypher generation; falls back to ``llm``.
            qa_llm: Model for answer generation; falls back to ``llm``.
            **kwargs: Forwarded to the class constructor (e.g. ``graph``).

        Returns:
            A ``KuzuQAChain`` wired with the two sub-chains.

        Raises:
            ValueError: If neither ``llm`` nor ``cypher_llm`` is given, if
                neither ``llm`` nor ``qa_llm`` is given, or if all three
                models are given at once.
        """
        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError(
                "Either `llm` or `qa_llm` parameters must be provided along with"
                " `cypher_llm`"
            )
        # With both dedicated models set, `llm` would never be used, so the
        # call is ambiguous and rejected.
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )

        qa_chain = LLMChain(
            llm=qa_llm or llm,  # type: ignore[arg-type]
            prompt=qa_prompt,
        )
        cypher_generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            prompt=cypher_prompt,
        )

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate a Cypher statement, run it against the database, and answer the question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the QA chain's
            answer.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        # NOTE(review): `get_schema` is read without being called, presumably
        # a property on KuzuGraph; confirm against KuzuGraph.
        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in triple backticks
        # with the language marker "cypher"
        generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher")

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        # The generated statement is executed as-is; see the security note in
        # the class docstring regarding database permissions.
        context = self.graph.query(generated_cypher)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}