"""BasePrompt模式定义。"""
from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type
import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """Check a jinja2 template against the declared input variables.

    Nothing is raised; mismatches are reported through ``warnings.warn``.

    Args:
        template: The jinja2 template source.
        input_variables: The variable names the caller declares for the template.
    """
    declared = set(input_variables)
    referenced = _get_jinja2_variables_from_template(template)

    missing = referenced - declared
    extra = declared - referenced

    pieces = []
    if missing:
        pieces.append(f"Missing variables: {missing} ")
    if extra:
        pieces.append(f"Extra variables: {extra}")

    message = "".join(pieces)
    if message:
        warnings.warn(message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
def mustache_template_vars(
    template: str,
) -> Set[str]:
    """Get the top-level variables referenced by a mustache template.

    Collected tags are plain variables (``{{x}}``), unescaped variables
    (``{{{x}}}`` / ``{{&x}}``), sections (``{{#x}}``) and inverted sections
    (``{{^x}}``). Tags nested inside a regular section are skipped, because
    they may be resolved against that section's value. An inverted section does
    not change the lookup context, so tags inside it are still collected unless
    a regular section encloses them. For dotted names (``{{a.b}}``) only the
    first component (``a``) is returned. The implicit iterator ``{{.}}`` is
    ignored.

    Args:
        template: The mustache template string.

    Returns:
        The set of top-level variable names.
    """
    found: Set[str] = set()
    # One entry per currently open section: True for a regular section (which
    # changes the lookup context), False for an inverted section. A stack,
    # rather than a single flag, keeps nested sections from ending the outer
    # section early.
    open_sections: List[bool] = []
    for tag_type, key in mustache.tokenize(template):
        if tag_type == "end":
            if open_sections:
                open_sections.pop()
            continue
        if (
            not any(open_sections)
            and key != "."
            and tag_type in ("variable", "no escape", "section", "inverted section")
        ):
            found.add(key.split(".")[0])
        if tag_type == "section":
            open_sections.append(True)
        elif tag_type == "inverted section":
            open_sections.append(False)
    return found
# Nested field tree used by ``mustache_schema``: each key maps to the fields
# nested beneath it; an empty dict marks a leaf field.
Defs = Dict[str, "Defs"]


def mustache_schema(
    template: str,
) -> Type[BaseModel]:
    """Build a pydantic model describing the inputs of a mustache template.

    Variables inside a regular section ``{{#a}}...{{/a}}`` are nested under
    ``a``. Dotted names (``{{a.b}}``) become nested fields. Inverted sections do
    not change the lookup context, so their contents stay at the enclosing level.
    The model is built by ``_create_model_recursive`` under the name
    ``"PromptInput"``.

    Args:
        template: The mustache template string.

    Returns:
        The generated model class.
    """
    fields: Set[Tuple[str, ...]] = set()
    prefix: Tuple[str, ...] = ()
    # Prefix to restore when each currently open section closes. Restoring a
    # saved value handles nested and dotted section names correctly.
    saved_prefixes: List[Tuple[str, ...]] = []
    for tag_type, key in mustache.tokenize(template):
        if tag_type == "end":
            if saved_prefixes:
                prefix = saved_prefixes.pop()
        elif tag_type in ("section", "inverted section"):
            saved_prefixes.append(prefix)
            # Only a regular section descends into its value; ``{{#.}}`` stays put.
            if tag_type == "section" and key != ".":
                prefix = prefix + tuple(key.split("."))
        elif tag_type in ("variable", "no escape") and key != ".":
            fields.add(prefix + tuple(key.split(".")))

    defs: Defs = {}  # an empty dict marks a leaf node
    # Sorted for a deterministic field order. ``setdefault`` on the leaf keeps
    # children already recorded for the same name (e.g. ``{{a}}`` and ``{{a.b}}``).
    for field in sorted(fields):
        current = defs
        for part in field[:-1]:
            current = current.setdefault(part, {})
        current.setdefault(field[-1], {})
    return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type:
    """Create a model named ``name`` from a nested ``Defs`` tree.

    A key whose value is a non-empty dict becomes a field typed as a nested
    model, built recursively under the key's name. A key mapping to an empty
    dict becomes a ``str`` field. Every field defaults to ``None``.
    """
    field_definitions: Dict[str, Tuple[Any, None]] = {}
    for field_name, children in defs.items():
        if children:
            field_type = _create_model_recursive(field_name, children)
        else:
            field_type = str
        field_definitions[field_name] = (field_type, None)
    return create_model(name, **field_definitions)  # type: ignore[call-overload]
# Maps each template format name to the callable used to format templates of
# that kind (presumably called as ``fn(template, **kwargs)`` — confirm with callers).
# NOTE(review): ``mustache_formatter`` and ``jinja2_formatter`` are not defined in
# the visible part of this module; confirm they are defined or imported before
# this point, otherwise importing the module raises NameError.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"mustache": mustache_formatter,
"jinja2": jinja2_formatter,
}
# Maps template formats to a validator that ``check_valid_template`` calls as
# ``validator(template, input_variables)``. There is no "mustache" entry, so
# mustache templates are rejected by ``check_valid_template``.
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that the template string is valid.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string" or
            "jinja2" (the formats that have a registered validator).
        input_variables: The input variables.

    Raises:
        ValueError: If the template format has no validator, or if the
            validator reports mismatched or missing input parameters.
    """
    try:
        validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
    except KeyError as exc:
        # List only formats that can actually be validated here. The formatter
        # mapping also contains formats (e.g. "mustache") that this lookup rejects.
        raise ValueError(
            f"Invalid template format {template_format!r}, should be one of"
            f" {list(DEFAULT_VALIDATOR_MAPPING)}."
        ) from exc
    try:
        validator_func(template, input_variables)
    except (KeyError, IndexError) as exc:
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters"
            f" from {input_variables}."
        ) from exc
def _get_f_string_variables(template: str) -> Set[str]:
    """Collect replacement-field names from an f-string style template.

    Field names nested inside a format spec (``"{value:{width}}"``) are included
    as well, because ``str.format`` fills them from the same keyword arguments.
    Escaped braces (``{{`` / ``}}``) produce no field.
    """
    names: Set[str] = set()
    for _, field_name, format_spec, _ in Formatter().parse(template):
        if field_name is None:
            continue
        names.add(field_name)
        if format_spec:
            names.update(_get_f_string_variables(format_spec))
    return names


def get_template_variables(template: str, template_format: str) -> List[str]:
    """Get the variables from the template.

    Args:
        template: The template string.
        template_format: The template format. Should be one of "f-string",
            "mustache" or "jinja2".

    Returns:
        The variables from the template, sorted.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format == "jinja2":
        input_variables = _get_jinja2_variables_from_template(template)
    elif template_format == "f-string":
        input_variables = _get_f_string_variables(template)
    elif template_format == "mustache":
        input_variables = mustache_template_vars(template)
    else:
        raise ValueError(f"Unsupported template format: {template_format}")
    return sorted(input_variables)
class StringPromptTemplate(BasePromptTemplate, ABC):
    """String prompt that exposes the format method, returning a prompt."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object.

        Returns:
            The namespace path ``["langchain", "prompts", "base"]``.
        """
        namespace = ["langchain", "prompts", "base"]
        return namespace

    def pretty_repr(self, html: bool = False) -> str:
        """Render the prompt with every input variable shown as a placeholder.

        Each input variable ``x`` is substituted with the literal text ``{x}``
        before calling ``self.format``.

        Args:
            html: If True, each placeholder is passed through
                ``get_colored_text(..., "yellow")`` before formatting.

        Returns:
            The formatted prompt string.
        """
        # TODO: handle partials
        placeholders: Dict[str, str] = {}
        for name in self.input_variables:
            placeholder = f"{{{name}}}"
            placeholders[name] = (
                get_colored_text(placeholder, "yellow") if html else placeholder
            )
        return self.format(**placeholders)

    def pretty_print(self) -> None:
        """Print ``pretty_repr``, coloring placeholders when ``is_interactive_env()`` is true."""
        rendered = self.pretty_repr(html=is_interactive_env())
        print(rendered)  # noqa: T201