"""LLM链用于将用户文本查询转换为结构化查询。"""
from __future__ import annotations
import json
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers.json import parse_and_check_json_markdown
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.runnables import Runnable
from langchain_core.structured_query import (
Comparator,
Comparison,
FilterDirective,
Operation,
Operator,
StructuredQuery,
)
from langchain.chains.llm import LLMChain
from langchain.chains.query_constructor.parser import get_parser
from langchain.chains.query_constructor.prompt import (
DEFAULT_EXAMPLES,
DEFAULT_PREFIX,
DEFAULT_SCHEMA_PROMPT,
DEFAULT_SUFFIX,
EXAMPLE_PROMPT,
EXAMPLES_WITH_LIMIT,
PREFIX_WITH_DATA_SOURCE,
SCHEMA_WITH_LIMIT_PROMPT,
SUFFIX_WITHOUT_DATA_SOURCE,
USER_SPECIFIED_EXAMPLE_PROMPT,
)
from langchain.chains.query_constructor.schema import AttributeInfo
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
    """Output parser that turns model text into a ``StructuredQuery``."""

    ast_parse: Callable
    """Callable that parses the raw filter value into the internal
    representation of the query language."""

    def parse(self, text: str) -> StructuredQuery:
        """Parse model output into a ``StructuredQuery``.

        The text is decoded with ``parse_and_check_json_markdown``, which is
        asked to require the keys ``"query"`` and ``"filter"``. The decoded
        payload is then normalised:

        * a ``None`` or empty query is replaced by a single space;
        * a filter equal to ``"NO_FILTER"`` or any falsy filter becomes
          ``None``; every other filter is passed to ``self.ast_parse``;
        * a missing or falsy ``"limit"`` is removed.

        Only the keys ``query``, ``filter`` and ``limit`` are forwarded to
        ``StructuredQuery``.

        Args:
            text: Raw text to parse.

        Returns:
            The structured query built from the decoded payload.

        Raises:
            OutputParserException: If any step fails. The original exception
                is attached as ``__cause__``.
        """
        try:
            expected_keys = ["query", "filter"]
            allowed_keys = ["query", "filter", "limit"]
            parsed = parse_and_check_json_markdown(text, expected_keys)
            if parsed["query"] is None or len(parsed["query"]) == 0:
                parsed["query"] = " "
            if parsed["filter"] == "NO_FILTER" or not parsed["filter"]:
                parsed["filter"] = None
            else:
                parsed["filter"] = self.ast_parse(parsed["filter"])
            if not parsed.get("limit"):
                parsed.pop("limit", None)
            return StructuredQuery(
                **{k: v for k, v in parsed.items() if k in allowed_keys}
            )
        except Exception as e:
            # Chain explicitly so the traceback shows the root cause as the
            # direct cause of the parser error.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e

    @classmethod
    def from_components(
        cls,
        allowed_comparators: Optional[Sequence[Comparator]] = None,
        allowed_operators: Optional[Sequence[Operator]] = None,
        allowed_attributes: Optional[Sequence[str]] = None,
        fix_invalid: bool = False,
    ) -> StructuredQueryOutputParser:
        """Create a structured query output parser from components.

        Args:
            allowed_comparators: Allowed comparators.
            allowed_operators: Allowed operators.
            allowed_attributes: Allowed attribute names.
            fix_invalid: If True, the filter is parsed with a parser built
                without any constraints and the result is then passed through
                ``fix_filter_directive`` together with the ``allowed_*``
                values. If False, the ``allowed_*`` values are handed to
                ``get_parser`` directly.

        Returns:
            A structured query output parser.
        """
        ast_parse: Callable
        if fix_invalid:

            def ast_parse(raw_filter: str) -> Optional[FilterDirective]:
                # Parse without constraints first, then repair the result.
                parsed_filter = cast(
                    Optional[FilterDirective], get_parser().parse(raw_filter)
                )
                return fix_filter_directive(
                    parsed_filter,
                    allowed_comparators=allowed_comparators,
                    allowed_operators=allowed_operators,
                    allowed_attributes=allowed_attributes,
                )

        else:
            ast_parse = get_parser(
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            ).parse
        return cls(ast_parse=ast_parse)
def fix_filter_directive(
    filter: Optional[FilterDirective],
    *,
    allowed_comparators: Optional[Sequence[Comparator]] = None,
    allowed_operators: Optional[Sequence[Operator]] = None,
    allowed_attributes: Optional[Sequence[str]] = None,
) -> Optional[FilterDirective]:
    """Fix an invalid filter directive.

    Parts of the directive that use a comparator, operator or attribute
    outside the allowed sets are removed. An ``Operation`` whose arguments are
    all removed is removed itself, and an AND/OR ``Operation`` left with a
    single argument is replaced by that argument.

    Args:
        filter: Filter directive to fix.
        allowed_comparators: Allowed comparators. A falsy value disables the
            comparator check.
        allowed_operators: Allowed operators. A falsy value disables the
            operator check.
        allowed_attributes: Allowed attributes. A falsy value disables the
            attribute check.

    Returns:
        The fixed filter directive, or ``None`` if nothing valid remains.
        ``filter`` is returned unchanged when it is falsy or when none of the
        ``allowed_*`` arguments is set.
    """
    if (
        not (allowed_comparators or allowed_operators or allowed_attributes)
    ) or not filter:
        return filter

    if isinstance(filter, Comparison):
        if allowed_comparators and filter.comparator not in allowed_comparators:
            return None
        if allowed_attributes and filter.attribute not in allowed_attributes:
            return None
        return filter

    if isinstance(filter, Operation):
        if allowed_operators and filter.operator not in allowed_operators:
            return None
        fixed_args = (
            fix_filter_directive(
                arg,
                allowed_comparators=allowed_comparators,
                allowed_operators=allowed_operators,
                allowed_attributes=allowed_attributes,
            )
            for arg in filter.arguments
            if arg is not None
        )
        # Drop children that were rejected by the recursive fix. Checking only
        # the inputs would leave ``None`` entries inside the rebuilt Operation.
        args = [fixed for fixed in fixed_args if fixed is not None]
        if not args:
            return None
        if len(args) == 1 and filter.operator in (Operator.AND, Operator.OR):
            # AND/OR over a single operand is just that operand.
            return args[0]
        return Operation(
            operator=filter.operator,
            arguments=args,
        )

    # Neither a Comparison nor an Operation: return it untouched.
    return filter
def _format_attribute_info(info: Sequence[Union[AttributeInfo, dict]]) -> str:
info_dicts = {}
for i in info:
i_dict = dict(i)
info_dicts[i_dict.pop("name")] = i_dict
return json.dumps(info_dicts, indent=4).replace("{", "{{").replace("}", "}}")
def construct_examples(input_output_pairs: Sequence[Tuple[str, dict]]) -> List[dict]:
    """Construct examples from input-output pairs.

    Args:
        input_output_pairs: Sequence of ``(user query, expected output dict)``
            pairs.

    Returns:
        One dict per pair with the keys ``"i"`` (1-based position),
        ``"user_query"`` and ``"structured_request"`` (the output dict dumped
        as indented JSON with every brace doubled).
    """

    def _escaped_json(payload: dict) -> str:
        # Dump as indented JSON and double every brace.
        return json.dumps(payload, indent=4).replace("{", "{{").replace("}", "}}")

    return [
        {
            "i": position,
            "user_query": user_query,
            "structured_request": _escaped_json(expected),
        }
        for position, (user_query, expected) in enumerate(input_output_pairs, start=1)
    ]
def get_query_constructor_prompt(
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> BasePromptTemplate:
    """Create the query construction prompt.

    Args:
        document_contents: Text inserted into the prompt as the description of
            the document contents.
        attribute_info: Sequence of ``AttributeInfo`` objects or dicts
            describing the document attributes.
        examples: Optional examples. If the first element is a tuple, the
            examples are treated as ``(input, output)`` pairs and converted
            with ``construct_examples``. If empty or ``None``, built-in
            examples are used.
        allowed_comparators: Sequence of allowed comparators.
        allowed_operators: Sequence of allowed operators.
        enable_limit: Whether to use the schema prompt and default examples
            that include the limit. Defaults to False.
        schema_prompt: Prompt describing the query schema. It is formatted
            with the string variables ``allowed_comparators`` and
            ``allowed_operators``.
        **kwargs: Additional named arguments passed to the
            ``FewShotPromptTemplate`` constructor.

    Returns:
        A prompt template that can be used to construct queries.
    """
    if not schema_prompt:
        schema_prompt = (
            SCHEMA_WITH_LIMIT_PROMPT if enable_limit else DEFAULT_SCHEMA_PROMPT
        )
    attributes_text = _format_attribute_info(attribute_info)
    comparators_text = " | ".join(allowed_comparators)
    operators_text = " | ".join(allowed_operators)
    schema = schema_prompt.format(
        allowed_comparators=comparators_text,
        allowed_operators=operators_text,
    )

    user_supplied_pairs = bool(examples) and isinstance(examples[0], tuple)
    if user_supplied_pairs:
        # Caller gave (input, output) tuples: the data source description
        # goes into the prefix rather than the suffix.
        examples = construct_examples(examples)
        example_prompt = USER_SPECIFIED_EXAMPLE_PROMPT
        prefix = PREFIX_WITH_DATA_SOURCE.format(
            schema=schema, content=document_contents, attributes=attributes_text
        )
        suffix = SUFFIX_WITHOUT_DATA_SOURCE.format(i=len(examples) + 1)
    else:
        if not examples:
            examples = EXAMPLES_WITH_LIMIT if enable_limit else DEFAULT_EXAMPLES
        example_prompt = EXAMPLE_PROMPT
        prefix = DEFAULT_PREFIX.format(schema=schema)
        suffix = DEFAULT_SUFFIX.format(
            i=len(examples) + 1,
            content=document_contents,
            attributes=attributes_text,
        )

    return FewShotPromptTemplate(
        examples=list(examples),
        example_prompt=example_prompt,
        input_variables=["query"],
        suffix=suffix,
        prefix=prefix,
        **kwargs,
    )
def load_query_constructor_chain(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    examples: Optional[List] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> LLMChain:
    """Load a query constructor chain.

    Args:
        llm: Language model to use for the chain.
        document_contents: Description of the document contents, forwarded to
            ``get_query_constructor_prompt``.
        attribute_info: Sequence of ``AttributeInfo`` objects or dicts (each
            with a ``"name"`` key) describing the document attributes.
        examples: Optional list of examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt describing the query schema, forwarded to
            ``get_query_constructor_prompt``.
        **kwargs: Arbitrary named arguments passed to ``LLMChain``.

    Returns:
        An ``LLMChain`` that can be used to construct queries.
    """
    constructor_prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    attribute_names = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
    )
    # For backwards compatibility.
    constructor_prompt.output_parser = parser
    return LLMChain(
        llm=llm,
        prompt=constructor_prompt,
        output_parser=parser,
        **kwargs,
    )
def load_query_constructor_runnable(
    llm: BaseLanguageModel,
    document_contents: str,
    attribute_info: Sequence[Union[AttributeInfo, dict]],
    *,
    examples: Optional[Sequence] = None,
    allowed_comparators: Sequence[Comparator] = tuple(Comparator),
    allowed_operators: Sequence[Operator] = tuple(Operator),
    enable_limit: bool = False,
    schema_prompt: Optional[BasePromptTemplate] = None,
    fix_invalid: bool = False,
    **kwargs: Any,
) -> Runnable:
    """Load a query constructor runnable.

    Args:
        llm: Language model to use for the chain.
        document_contents: Description of the page contents of the documents
            to be queried, forwarded to ``get_query_constructor_prompt``.
        attribute_info: Sequence of ``AttributeInfo`` objects or dicts (each
            with a ``"name"`` key) describing the document attributes.
        examples: Optional examples to use for the chain.
        allowed_comparators: Sequence of allowed comparators. Defaults to all
            comparators.
        allowed_operators: Sequence of allowed operators. Defaults to all
            operators.
        enable_limit: Whether to enable the limit operator. Defaults to False.
        schema_prompt: Prompt describing the query schema, forwarded to
            ``get_query_constructor_prompt``.
        fix_invalid: Whether to fix invalid filter directives by ignoring
            invalid operators, comparators and attributes. Forwarded to
            ``StructuredQueryOutputParser.from_components``.
        **kwargs: Additional named arguments forwarded to
            ``get_query_constructor_prompt``.

    Returns:
        A ``Runnable`` composed as prompt, then ``llm``, then output parser.
    """
    constructor_prompt = get_query_constructor_prompt(
        document_contents,
        attribute_info,
        examples=examples,
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        enable_limit=enable_limit,
        schema_prompt=schema_prompt,
        **kwargs,
    )
    # Attribute names come from either AttributeInfo objects or plain dicts.
    attribute_names = [
        entry.name if isinstance(entry, AttributeInfo) else entry["name"]
        for entry in attribute_info
    ]
    parser = StructuredQueryOutputParser.from_components(
        allowed_comparators=allowed_comparators,
        allowed_operators=allowed_operators,
        allowed_attributes=attribute_names,
        fix_invalid=fix_invalid,
    )
    return constructor_prompt | llm | parser