Source code for langchain_community.chains.graph_qa.cypher

"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain_community.chains.graph_qa.cypher_utils import (
    CypherQueryCorrector,
    Schema,
)
from langchain_community.chains.graph_qa.prompts import (
    CYPHER_GENERATION_PROMPT,
    CYPHER_QA_PROMPT,
)
from langchain_community.graphs.graph_store import GraphStore

# Output-dict key under which the chain returns its intermediate steps
# (the generated Cypher query and the retrieved context).
INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        The contents of the first triple-backtick fenced block found in
        ``text``, without a leading ``cypher`` language-tag line. When
        ``text`` contains no complete fenced block it is returned unchanged.
    """
    # An opening fence may carry a language tag on its own line
    # (```cypher\n...). Drop it so the tag is not treated as part of the
    # query. The tag must be followed by a newline: a query that begins with
    # the ``CYPHER <options>`` prefix on the same line is left intact.
    pattern = r"```(?:cypher[ \t]*\r?\n)?(.*?)```"
    # Only the first fenced block is used, so a single search is enough.
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1) if match else text
def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Filter a structured graph schema by type and render it as text.

    Args:
        structured_schema: Mapping that may contain ``node_props`` and
            ``rel_props`` (name -> list of ``{"property", "type"}`` dicts)
            and ``relationships`` (list of ``{"start", "type", "end"}`` dicts).
        include_types: When non-empty, only these names are kept.
        exclude_types: Consulted only when ``include_types`` is empty; these
            names are dropped.

    Returns:
        A six-line string: three fixed headings, each followed by a
        comma-joined line of the entries that passed the filter.
    """

    def is_allowed(name: str) -> bool:
        # An include list, when given, takes precedence over the exclude list.
        if include_types:
            return name in include_types
        return name not in exclude_types

    def keep_allowed(section: str) -> Dict[str, Any]:
        # Keep only the entries of a property section whose name is allowed.
        kept: Dict[str, Any] = {}
        for name, props in structured_schema.get(section, {}).items():
            if is_allowed(name):
                kept[name] = props
        return kept

    def render_props(props_by_name: Dict[str, Any]) -> List[str]:
        # Render each entry as ``Name {prop: TYPE, ...}``.
        rendered = []
        for name, props in props_by_name.items():
            joined = ", ".join(
                f"{entry['property']}: {entry['type']}" for entry in props
            )
            rendered.append(f"{name} {{{joined}}}")
        return rendered

    # Filter everything first, then format, in the same order as the headings.
    kept_node_props = keep_allowed("node_props")
    kept_rel_props = keep_allowed("rel_props")
    kept_relationships = []
    for rel in structured_schema.get("relationships", []):
        # A relationship survives only if its start, end and type all pass.
        if all(is_allowed(rel[part]) for part in ("start", "end", "type")):
            kept_relationships.append(rel)

    node_lines = render_props(kept_node_props)
    rel_prop_lines = render_props(kept_rel_props)
    relationship_lines = []
    for rel in kept_relationships:
        relationship_lines.append(
            f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})"
        )

    sections = (
        "Node properties are the following:",
        ",".join(node_lines),
        "Relationship properties are the following:",
        ",".join(rel_prop_lines),
        "The relationships are the following:",
        ",".join(relationship_lines),
    )
    return "\n".join(sections)
class GraphCypherQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion or mutation of
    data if appropriately prompted, or in reading sensitive data if such data
    is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: GraphStore = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    graph_schema: str
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query."""
    return_intermediate_steps: bool = False
    """Whether to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether to return the result of querying the graph directly."""
    cypher_query_corrector: Optional[CypherQueryCorrector] = None
    """Optional Cypher validation tool."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: Optional[BasePromptTemplate] = None,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        exclude_types: List[str] = [],
        include_types: List[str] = [],
        validate_cypher: bool = False,
        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> GraphCypherQAChain:
        """Initialize from LLM.

        Args:
            llm: Model used for any chain that has no dedicated model.
            qa_prompt: Prompt for the answering chain; defaults to
                ``CYPHER_QA_PROMPT``. Mutually exclusive with ``qa_llm_kwargs``.
            cypher_prompt: Prompt for the Cypher-generation chain; defaults to
                ``CYPHER_GENERATION_PROMPT``. Mutually exclusive with
                ``cypher_llm_kwargs``.
            cypher_llm: Model used only for Cypher generation.
            qa_llm: Model used only for answering.
            exclude_types: Schema type names to leave out of the prompt schema.
            include_types: Schema type names to keep; mutually exclusive with
                ``exclude_types``.
            validate_cypher: When true, build a ``CypherQueryCorrector`` from
                the graph's relationships and attach it to the chain.
            qa_llm_kwargs: Extra keyword arguments for the answering
                ``LLMChain``. The caller's dict is not modified.
            cypher_llm_kwargs: Extra keyword arguments for the generation
                ``LLMChain``. The caller's dict is not modified.
            **kwargs: Forwarded to the class constructor; must contain
                ``graph``.

        Returns:
            A configured ``GraphCypherQAChain``.

        Raises:
            ValueError: If the model, prompt/kwargs or include/exclude
                arguments are combined in an unsupported way.
        """
        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )
        if cypher_prompt and cypher_llm_kwargs:
            raise ValueError(
                "Specifying cypher_prompt and cypher_llm_kwargs together is"
                " not allowed. Please pass prompt via cypher_llm_kwargs."
            )
        if qa_prompt and qa_llm_kwargs:
            raise ValueError(
                "Specifying qa_prompt and qa_llm_kwargs together is"
                " not allowed. Please pass prompt via qa_llm_kwargs."
            )

        # Work on shallow copies: a default "prompt" entry may be inserted
        # below, and that must not leak into the dicts the caller passed in.
        use_qa_llm_kwargs = dict(qa_llm_kwargs) if qa_llm_kwargs is not None else {}
        use_cypher_llm_kwargs = (
            dict(cypher_llm_kwargs) if cypher_llm_kwargs is not None else {}
        )
        if "prompt" not in use_qa_llm_kwargs:
            use_qa_llm_kwargs["prompt"] = (
                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
            )
        if "prompt" not in use_cypher_llm_kwargs:
            use_cypher_llm_kwargs["prompt"] = (
                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
            )

        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]

        cypher_generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            **use_cypher_llm_kwargs,  # type: ignore[arg-type]
        )

        if exclude_types and include_types:
            raise ValueError(
                "Either `exclude_types` or `include_types` "
                "can be provided, but not both"
            )

        graph_schema = construct_schema(
            kwargs["graph"].get_structured_schema, include_types, exclude_types
        )

        cypher_query_corrector = None
        if validate_cypher:
            # Default to an empty list so a schema without a "relationships"
            # entry yields a corrector with no rules instead of a TypeError.
            corrector_schema = [
                Schema(el["start"], el["type"], el["end"])
                for el in kwargs["graph"].structured_schema.get("relationships", [])
            ]
            cypher_query_corrector = CypherQueryCorrector(corrector_schema)

        return cls(
            graph_schema=graph_schema,
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            cypher_query_corrector=cypher_query_corrector,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate a Cypher statement, run it on the graph and answer the question.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Callback manager for this run; a no-op manager is
                used when omitted.

        Returns:
            A dict with the answer (or the raw query rows when
            ``return_direct`` is set) under ``self.output_key`` and, when
            ``return_intermediate_steps`` is set, the collected steps under
            ``INTERMEDIATE_STEPS_KEY``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List[Dict[str, Any]] = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        # Correct Cypher query if enabled
        if self.cypher_query_corrector:
            generated_cypher = self.cypher_query_corrector(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results.
        # The generated Cypher may be empty if the query corrector identifies
        # an invalid schema; in that case the database is not queried.
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result