Source code for langchain_core.prompts.string

"""BasePrompt schema definition."""

from __future__ import annotations

import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set, Tuple, Type

import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env


def jinja2_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using jinja2.

    *Security warning*: As of LangChain 0.0.329, this method uses Jinja2's
    SandboxedEnvironment by default. However, this sandboxing should be
    treated as a best-effort approach rather than a guarantee of security.
    Do not accept jinja2 templates from untrusted sources, as they may lead
    to arbitrary Python code execution.

    https://jinja.palletsprojects.com/en/3.1.x/sandbox/

    Args:
        template: The jinja2 template string.
        **kwargs: Values made available to the template while rendering.

    Returns:
        The rendered template.

    Raises:
        ImportError: If jinja2 is not installed.
    """
    try:
        from jinja2.sandbox import SandboxedEnvironment
    except ImportError as exc:
        # Chain the original error so the underlying import failure stays
        # visible in the traceback.
        raise ImportError(
            "jinja2 not installed, which is needed to use the jinja2_formatter. "
            "Please install it with `pip install jinja2`. "
            "Please be cautious when using jinja2 templates. "
            "Do not expand jinja2 templates using unverified or user-controlled "
            "inputs as that can result in arbitrary Python code execution."
        ) from exc

    # This uses a sandboxed environment to prevent arbitrary code execution.
    # Jinja2 uses an opt-out rather than opt-in approach for sand-boxing.
    # Please treat this sand-boxing as a best-effort approach rather than
    # a guarantee of security.
    # We recommend to never use jinja2 templates with untrusted inputs.
    # https://jinja.palletsprojects.com/en/3.1.x/sandbox/
    return SandboxedEnvironment().from_string(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """Validate that the input variables match those used by the template.

    Emits a warning if variables are missing or if extra variables are given.

    Args:
        template: The template string.
        input_variables: The input variables.
    """
    declared = set(input_variables)
    referenced = _get_jinja2_variables_from_template(template)

    missing_variables = referenced - declared
    extra_variables = declared - referenced

    # Collect the applicable message parts, then emit one combined warning.
    fragments = []
    if missing_variables:
        fragments.append(f"Missing variables: {missing_variables} ")
    if extra_variables:
        fragments.append(f"Extra variables: {extra_variables}")

    message = "".join(fragments).strip()
    if message:
        warnings.warn(message)
def _get_jinja2_variables_from_template(template: str) -> Set[str]: try: from jinja2 import Environment, meta except ImportError: raise ImportError( "jinja2 not installed, which is needed to use the jinja2_formatter. " "Please install it with `pip install jinja2`." ) env = Environment() ast = env.parse(template) variables = meta.find_undeclared_variables(ast) return variables
def mustache_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using mustache.

    Args:
        template: The mustache template string.
        **kwargs: Values passed to the renderer as the template data.

    Returns:
        The rendered template.
    """
    rendered = mustache.render(template, kwargs)
    return rendered
def mustache_template_vars(
    template: str,
) -> Set[str]:
    """Get the top-level variables from a mustache template.

    Only names that appear outside of every section are collected; a section
    contributes its own name, but nothing inside it. For dotted names only
    the first component is kept, and the implicit iterator ``.`` is ignored.

    Args:
        template: The mustache template string.

    Returns:
        The set of top-level variable names used by the template.
    """
    top_level_vars: Set[str] = set()
    # Keys of the sections currently open. A stack (rather than a single
    # flag) is needed so that closing a nested section does not make the
    # rest of the enclosing section look like top-level content.
    open_sections: List[str] = []
    for token_type, key in mustache.tokenize(template):
        if token_type == "end":
            # Only close a section we are tracking; "end" tokens that belong
            # to any other construct are ignored.
            if open_sections and open_sections[-1] == key:
                open_sections.pop()
        elif open_sections:
            continue
        elif token_type in ("variable", "section") and key != ".":
            top_level_vars.add(key.split(".")[0])
            if token_type == "section":
                open_sections.append(key)
    return top_level_vars
# Recursive mapping describing nested template fields: each key is a field
# name and each value is the mapping of that field's own sub-fields (an empty
# mapping marks a leaf).
Defs = Dict[str, "Defs"]
def mustache_schema(
    template: str,
) -> Type[BaseModel]:
    """Build a model describing the input structure of a mustache template.

    Every variable is recorded together with the sections that enclose it,
    and the resulting paths are folded into a nested definition tree which
    is converted into a model named ``PromptInput``.

    Args:
        template: The mustache template string.

    Returns:
        The model class generated from the template's variables.
    """
    fields: Set[Tuple[str, ...]] = set()
    prefix: Tuple[str, ...] = ()
    # Stack of (section key, prefix in effect before the section opened).
    # Restoring the saved prefix on "end" avoids slicing by a computed
    # length, which wiped the whole prefix for undotted keys in nested
    # sections and left a stale part behind for dotted keys.
    open_sections: List[Tuple[str, Tuple[str, ...]]] = []
    for token_type, key in mustache.tokenize(template):
        if key == ".":
            continue
        if token_type == "end":
            # Only unwind sections that were actually pushed below.
            if open_sections and open_sections[-1][0] == key:
                _, prefix = open_sections.pop()
        elif token_type == "section":
            open_sections.append((key, prefix))
            prefix = prefix + tuple(key.split("."))
        elif token_type == "variable":
            fields.add(prefix + tuple(key.split(".")))

    defs: Defs = {}  # an empty dict marks a leaf node
    for field in fields:
        current = defs
        for part in field[:-1]:
            current = current.setdefault(part, {})
        # setdefault keeps nested definitions already collected for this
        # name, so the result does not depend on set iteration order.
        current.setdefault(field[-1], {})
    return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type:
    """Create a model from a nested field-definition tree.

    A field whose definition has children becomes a nested model named after
    the field; a field without children is declared as ``str``. Every field
    defaults to ``None``.
    """
    field_definitions: Dict[str, Any] = {}
    for field_name, children in defs.items():
        if children:
            field_definitions[field_name] = (
                _create_model_recursive(field_name, children),
                None,
            )
        else:
            field_definitions[field_name] = (str, None)
    return create_model(name, **field_definitions)  # type: ignore[call-overload]


# Rendering function for each supported template format.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.format,
    "mustache": mustache_formatter,
    "jinja2": jinja2_formatter,
}

# Input-variable validation function for each template format that has one.
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.validate_input_variables,
    "jinja2": validate_jinja2,
}
def check_valid_template(
    template: str, template_format: str, input_variables: List[str]
) -> None:
    """Check that a template string is valid.

    Args:
        template: The template string.
        template_format: The template format. Must be a key of
            ``DEFAULT_VALIDATOR_MAPPING``.
        input_variables: The input variables.

    Raises:
        ValueError: If the template format is not supported, or if the
            validator fails with a ``KeyError`` or ``IndexError``.
    """
    try:
        validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
    except KeyError as exc:
        # List the formats that can actually be validated here, i.e. the
        # keys of the mapping that was just consulted.
        raise ValueError(
            f"Invalid template format {template_format!r}, should be one of"
            f" {list(DEFAULT_VALIDATOR_MAPPING)}."
        ) from exc
    try:
        validator_func(template, input_variables)
    except (KeyError, IndexError) as exc:
        raise ValueError(
            "Invalid prompt schema; check for mismatched or missing input parameters"
            f" from {input_variables}."
        ) from exc
def get_template_variables(template: str, template_format: str) -> List[str]:
    """Get the variables from a template.

    Args:
        template: The template string.
        template_format: The template format. One of "f-string", "jinja2"
            or "mustache".

    Returns:
        The variables found in the template, sorted.

    Raises:
        ValueError: If the template format is not supported.
    """
    if template_format == "f-string":
        field_names = set()
        for _, field_name, _, _ in Formatter().parse(template):
            # Literal-only segments carry no field name.
            if field_name is not None:
                field_names.add(field_name)
        return sorted(field_names)
    if template_format == "jinja2":
        return sorted(_get_jinja2_variables_from_template(template))
    if template_format == "mustache":
        return sorted(mustache_template_vars(template))
    raise ValueError(f"Unsupported template format: {template_format}")
class StringPromptTemplate(BasePromptTemplate, ABC):
    """String prompt that exposes the format method, returning a prompt."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        namespace = ["langchain", "prompts", "base"]
        return namespace

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the prompt with the given inputs and wrap it as a value."""
        formatted_text = self.format(**kwargs)
        return StringPromptValue(text=formatted_text)

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Async variant of ``format_prompt``."""
        formatted_text = await self.aformat(**kwargs)
        return StringPromptValue(text=formatted_text)

    def pretty_repr(self, html: bool = False) -> str:
        """Return the prompt formatted with ``{name}`` placeholders.

        Args:
            html: Whether to pass each placeholder through
                ``get_colored_text`` with the colour "yellow".
        """
        # TODO: handle partials
        placeholders = {}
        for name in self.input_variables:
            placeholder = f"{{{name}}}"
            if html:
                placeholder = get_colored_text(placeholder, "yellow")
            placeholders[name] = placeholder
        return self.format(**placeholders)

    def pretty_print(self) -> None:
        """Print the pretty representation of the prompt."""
        use_html = is_interactive_env()
        print(self.pretty_repr(html=use_html))  # noqa: T201