# Source code for langchain.chains.query_constructor.base

"""LLM Chain for turning a user text query into a structured query."""
from __future__ import annotations

import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast

from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers.json import parse_and_check_json_markdown
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.structured_query import (
    Comparator,
    Comparison,
    FilterDirective,
    Operation,
    Operator,
    StructuredQuery,
)

from langchain.chains.llm import LLMChain
from langchain.chains.query_constructor.parser import get_parser
from langchain.chains.query_constructor.prompt import (
    DEFAULT_EXAMPLES,
    DEFAULT_PREFIX,
    DEFAULT_SCHEMA_PROMPT,
    DEFAULT_SUFFIX,
    EXAMPLE_PROMPT,
    EXAMPLES_WITH_LIMIT,
    PREFIX_WITH_DATA_SOURCE,
    SCHEMA_WITH_LIMIT_PROMPT,
    SUFFIX_WITHOUT_DATA_SOURCE,
    USER_SPECIFIED_EXAMPLE_PROMPT,
)
from langchain.chains.query_constructor.schema import AttributeInfo


class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
    """Output parser that turns LLM text into a ``StructuredQuery``."""

    ast_parse: Callable
    """Callable that parses dict into internal representation of query language."""

    def parse(self, text: str) -> StructuredQuery:
        """Parse model output into a ``StructuredQuery``.

        The text is handed to ``parse_and_check_json_markdown`` together with
        the expected keys ``"query"`` and ``"filter"``. The resulting mapping
        is then normalised:

        * a ``None`` or empty ``query`` is replaced by a single space;
        * a ``filter`` equal to ``"NO_FILTER"`` (or any falsy value) becomes
          ``None``; any other value is passed through ``self.ast_parse``;
        * a missing or falsy ``limit`` is dropped.

        Only the keys ``query``, ``filter`` and ``limit`` are forwarded to
        ``StructuredQuery``.

        Args:
            text: Raw text produced by the language model.

        Returns:
            The structured query built from the parsed text.

        Raises:
            OutputParserException: If any step above raises. The original
                error is attached as ``__cause__``.
        """
        try:
            expected_keys = ["query", "filter"]
            allowed_keys = ["query", "filter", "limit"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if parsed["query"] is None or len(parsed["query"]) == 0:
                parsed["query"] = " "
            if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
                parsed["filter"] = None
            else:
                parsed["filter"] = self.ast_parse(parsed["filter"])
            if not parsed.get("limit"):
                parsed.pop("limit", None)
            return StructuredQuery(
                **{k: v for k, v in parsed.items() if k in allowed_keys}
            )
        except Exception as e:
            # Chain explicitly so the root cause stays attached to the
            # parser-level error.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e

    @classmethod
    def from_components(
        cls,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        allowed_attributes: Optional[Sequence[str]] = None,
        fix_invalid: bool = False,
    ) -> StructuredQueryOutputParser:
        """Create a structured query output parser from components.

        Args:
            allowed_comparators: Allowed comparators.
            allowed_operators: Allowed operators.
            allowed_attributes: Allowed attribute names.
            fix_invalid: When true, the raw filter is parsed by a parser
                created without any allow-lists, and the result is passed
                through ``fix_filter_directive`` with the allow-lists. When
                false, the allow-lists are handed to ``get_parser`` directly.

        Returns:
            A structured query output parser.
        """
        ast_parse: Callable
        if fix_invalid:

            def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
                parsed_filter = cast(
                    Optional[FilterDirective], get_parser().parse(raw_filter)
                )
                return fix_filter_directive(
                    parsed_filter,
                    allowed_comparators=allowed_comparators,
                    allowed_operators=allowed_operators,
                    allowed_attributes=allowed_attributes,
                )

        else:
            ast_parse = get_parser(
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            ).parse
        return cls(ast_parse=ast_parse)
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Fix an invalid filter directive.

    Parts of the directive that use a comparator, operator or attribute
    outside the given allow-lists are removed. An empty or ``None`` allow-list
    places no restriction on that dimension.

    Args:
        filter: Filter directive to fix.
        allowed_comparators: Allowed comparators. Defaults to all comparators.
        allowed_operators: Allowed operators. Defaults to all operators.
        allowed_attributes: Allowed attributes. Defaults to all attributes.

    Returns:
        The fixed filter directive, or ``None`` when nothing valid remains.
        The input is returned unchanged when no allow-list is given, when it
        is falsy, or when it is neither a ``Comparison`` nor an ``Operation``.
    """
    if (
        not (allowed_comparators or allowed_operators or allowed_attributes)
    ) or not filter:
        return filter
    elif isinstance(filter, Comparison):
        if allowed_comparators and filter.comparator not in allowed_comparators:
            return None
        if allowed_attributes and filter.attribute not in allowed_attributes:
            return None
        return filter
    elif isinstance(filter, Operation):
        if allowed_operators and filter.operator not in allowed_operators:
            return None
        fixed_args = (
            fix_filter_directive(
                arg,
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            )
            for arg in filter.arguments
        )
        # Filter on the *fixed* result: a child that was itself invalid comes
        # back as None and must not be kept as an argument of the rebuilt
        # operation.
        args = [arg for arg in fixed_args if arg is not None]
        if not args:
            return None
        elif len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
            # An AND/OR over a single remaining argument is just that argument.
            return args[0]
        else:
            return Operation(
                operator=filter.operator,
                arguments=args,
            )
    else:
        return filter
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str: info_dicts = {} for i in info: i_dict = dict(i) info_dicts[i_dict.pop("name")] = i_dict return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Construct examples from input-output pairs.

    Args:
        input_output_pairs: Sequence of input-output pairs.

    Returns:
        List of examples. Each example holds its 1-based position under
        ``"i"``, the input under ``"user_query"`` and the output, dumped as
        indented JSON with every brace doubled, under ``"structured_request"``.
    """

    def _escaped_json(payload: dict) -> str:
        # Double the braces in the dumped JSON text.
        return json.dumps(payload, indent=4).replace("{", "{{").replace("}", "}}")

    return [
        {
            "i": position,
            "user_query": question,
            "structured_request": _escaped_json(answer),
        }
        for position, (question, answer) in enumerate(input_output_pairs, start=1)
    ]
def get_query_constructor_prompt(
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> BasePromptTemplate:
    """Create the query construction prompt.

    Args:
        document_contents: The contents of the document to be queried.
        attribute_info: Descriptions of the document's attributes, given as
            AttributeInfo objects or dicts.
        examples: Optional list of examples to use for the chain. When the
            first example is a tuple, the examples are treated as
            input-output pairs and converted with ``construct_examples``.
        allowed_comparators: Sequence of allowed comparators.
        allowed_operators: Sequence of allowed operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing the query schema. Should have
            string input variables allowed_comparators and allowed_operators.
        **kwargs: Additional named params to pass to FewShotPromptTemplate
            init.

    Returns:
        A prompt template that can be used to construct queries.
    """
    if not schema_prompt:
        schema_prompt = (
            SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
        )
    attribute_str = _format_attribute_info(attribute_info)
    comparator_choices = " | ".join(allowed_comparators)
    operator_choices = " | ".join(allowed_operators)
    schema = schema_prompt.format(
        allowed_comparators=comparator_choices,
        allowed_operators=operator_choices,
    )

    given_as_pairs = bool(examples) and isinstance(examples[0], tuple)
    if given_as_pairs:
        # User-specified (input, output) pairs: the data source description
        # goes into the prefix.
        examples = construct_examples(examples)
        example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
        prefix = PREFIX_WITH_DATA_SOURCE.format(
            schema=schema, content=document_contents, attributes=attribute_str
        )
        suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
    else:
        # Ready-made or default examples: the data source description goes
        # into the suffix.
        if not examples:
            examples = EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
        example_prompt = EXAMPLE_PROMPT
        prefix = DEFAULT_PREFIX.format(schema=schema)
        suffix = DEFAULT_SUFFIX.format(
            i=len(examples) + 1, content=document_contents, attributes=attribute_str
        )

    return FewShotPromptTemplate(
        examples=list(examples),
        example_prompt=example_prompt,
        input_variables=["query"],
        suffix=suffix,
        prefix=prefix,
        **kwargs,
    )
def load_query_constructor_chain(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    examples: Optional[List] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> LLMChain:
    """Load a query constructor chain.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: The contents of the document to be queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing the query schema. Should have
            string input variables allowed_comparators and allowed_operators.
        **kwargs: Arbitrary named params to pass to LLMChain.

    Returns:
        An LLMChain that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    attribute_names = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
    )
    # For backwards compatibility.
    prompt.output_parser = parser
    return LLMChain(llm=llm, prompt=prompt, output_parser=parser, **kwargs)
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Load a query constructor runnable chain.

    Args:
        llm: BaseLanguageModel to use for the chain.
        document_contents: Description of the page contents of the document
            to be queried.
        attribute_info: Sequence of attributes in the document.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            Comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            Operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt for describing the query schema. Should have
            string input variables allowed_comparators and allowed_operators.
        fix_invalid: Whether to fix invalid filter directives by ignoring
            invalid operators, comparators and attributes.
        **kwargs: Additional named params; they are forwarded to
            ``get_query_constructor_prompt``, which passes them on to
            FewShotPromptTemplate init.

    Returns:
        A Runnable that can be used to construct queries.
    """
    prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    attribute_names = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
        fix_invalid=fix_invalid,
    )
    return prompt | llm | parser