Source code for langchain_core.prompts.base

from __future__ import annotations

import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Mapping,
    Optional,
    Type,
    TypeVar,
    Union,
)

import yaml

from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
    ChatPromptValueConcrete,
    PromptValue,
    StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model

if TYPE_CHECKING:
    from langchain_core.documents import Document


FormatOutputType = TypeVar("FormatOutputType")


class BasePromptTemplate(
    RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    input_types: Dict[str, Any] = Field(default_factory=dict)
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Any] = Field(default_factory=dict)
    """A dictionary of the partial variables the prompt template carries.

    Partial variables populate the template so that they do not need to be
    passed in every time the prompt is called."""
    metadata: Optional[Dict[str, Any]] = None
    """Metadata to be used for tracing."""
    tags: Optional[List[str]] = None
    """Tags to be used for tracing."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "prompt_template"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def OutputType(self) -> Any:
        """The union of prompt value types this template may produce."""
        return Union[StringPromptValue, ChatPromptValueConcrete]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Build a pydantic model describing the expected input.

        Every name in ``input_variables`` becomes an optional field whose type
        comes from ``input_types`` and falls back to ``str``.

        Args:
            config: Accepted for interface compatibility; not used here.

        Returns:
            A dynamically created model named ``PromptInput``.
        """
        # This is correct, but pydantic typings/mypy don't think so.
        return create_model(  # type: ignore[call-overload]
            "PromptInput",
            **{k: (self.input_types.get(k, str), None) for k in self.input_variables},
        )

    def _validate_input(self, inner_input: Dict) -> Dict:
        """Normalise the input to a dict and check that no variable is missing.

        A non-dict input is accepted only when the template has exactly one
        input variable, in which case it is wrapped under that variable's name.

        Raises:
            TypeError: If the input is not a dict and the template does not
                have exactly one input variable.
            KeyError: If any name in ``input_variables`` is absent.
        """
        if not isinstance(inner_input, dict):
            if len(self.input_variables) == 1:
                var_name = self.input_variables[0]
                inner_input = {var_name: inner_input}
            else:
                raise TypeError(
                    f"Expected mapping type as input to {self.__class__.__name__}. "
                    f"Received {type(inner_input)}."
                )
        missing = set(self.input_variables).difference(inner_input)
        if missing:
            raise KeyError(
                f"Input to {self.__class__.__name__} is missing variables {missing}. "
                f" Expected: {self.input_variables}"
                f" Received: {list(inner_input.keys())}"
            )
        return inner_input

    def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
        """Validate the input, then delegate to ``format_prompt``."""
        _inner_input = self._validate_input(inner_input)
        return self.format_prompt(**_inner_input)

    async def _aformat_prompt_with_error_handling(
        self, inner_input: Dict
    ) -> PromptValue:
        """Validate the input, then delegate to ``aformat_prompt``."""
        _inner_input = self._validate_input(inner_input)
        return await self.aformat_prompt(**_inner_input)

    def invoke(
        self, input: Dict, config: Optional[RunnableConfig] = None
    ) -> PromptValue:
        """Format the input into a prompt value.

        The template's own ``metadata`` and ``tags`` are merged into the run
        config before the call is dispatched with run type ``"prompt"``.

        Args:
            input: Values for the template's input variables.
            config: Optional runnable configuration.

        Returns:
            The formatted prompt value.
        """
        config = ensure_config(config)
        # New containers are built rather than mutating the existing ones, so
        # objects that may be shared with the caller's config stay untouched.
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return self._call_with_config(
            self._format_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
        )

    async def ainvoke(
        self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Asynchronously format the input into a prompt value.

        The template's own ``metadata`` and ``tags`` are merged into the run
        config before the call is dispatched with run type ``"prompt"``.

        Args:
            input: Values for the template's input variables.
            config: Optional runnable configuration.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The formatted prompt value.
        """
        config = ensure_config(config)
        # Merge into fresh containers, exactly as ``invoke`` does. Calling
        # ``.update()`` / ``.extend()`` in place could leak this template's
        # metadata and tags into objects shared with the caller's config and
        # make them pile up across repeated calls.
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return await self._acall_with_config(
            self._aformat_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
        )

    @abstractmethod
    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Create a prompt value."""

    async def aformat_prompt(self, **kwargs: Any) -> PromptValue:
        """Create a prompt value.

        The default implementation simply calls ``format_prompt``.
        """
        return self.format_prompt(**kwargs)

    @root_validator()
    def validate_variable_names(cls, values: Dict) -> Dict:
        """Validate that variable names do not include restricted names.

        Raises:
            ValueError: If an input or partial variable is named ``stop``, or
                if a name is both an input and a partial variable.
        """
        if "stop" in values["input_variables"]:
            raise ValueError(
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
        if "stop" in values["partial_variables"]:
            raise ValueError(
                "Cannot have an partial variable named 'stop', as it is used "
                "internally, please rename."
            )

        overall = set(values["input_variables"]).intersection(
            values["partial_variables"]
        )
        if overall:
            raise ValueError(
                f"Found overlapping input and partial variables: {overall}"
            )
        return values

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template.

        Args:
            kwargs: Values, or zero-argument callables producing values, to
                bind as partial variables.

        Returns:
            A new template of the same type in which the given names have been
            moved from ``input_variables`` to ``partial_variables``.
        """
        prompt_dict = self.__dict__.copy()
        # Keep the declared order of the remaining variables. Going through a
        # plain ``set`` would shuffle them unpredictably. ``dict.fromkeys``
        # still removes duplicates, as the set-based version did.
        prompt_dict["input_variables"] = [
            name for name in dict.fromkeys(self.input_variables) if name not in kwargs
        ]
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        """Combine partial variables with call-time values.

        Callable partial variables are invoked to obtain their value. Values
        passed in ``kwargs`` take precedence over partial variables.
        """
        # Get partial params:
        partial_kwargs = {
            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @abstractmethod
    def format(self, **kwargs: Any) -> FormatOutputType:
        """Format the prompt with the inputs.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            prompt.format(variable1="foo")
        """

    async def aformat(self, **kwargs: Any) -> FormatOutputType:
        """Format the prompt with the inputs.

        The default implementation simply calls ``format``.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A formatted string.

        Example:

        .. code-block:: python

            await prompt.aformat(variable1="foo")
        """
        return self.format(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary representation of the prompt.

        A ``"_type"`` entry is added when the template defines
        ``_prompt_type``; otherwise it is omitted.
        """
        prompt_dict = super().dict(**kwargs)
        try:
            prompt_dict["_type"] = self._prompt_type
        except NotImplementedError:
            # Templates without a type key are still representable; ``save``
            # is the place that rejects them.
            pass
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt.

        Args:
            file_path: Path of the file to write. The suffix selects the
                format: ``.json``, or ``.yaml`` / ``.yml``. Missing parent
                directories are created.

        Raises:
            ValueError: If the prompt has partial variables, or the suffix is
                not one of the supported ones.
            NotImplementedError: If the prompt has no ``_type`` entry in its
                dictionary representation.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        # Fetch dictionary to save
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            raise NotImplementedError(f"Prompt {self} does not support saving.")

        # Normalise once so every later step works on the same Path object.
        save_path = Path(file_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        suffix = save_path.suffix
        if suffix == ".json":
            with save_path.open("w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif suffix in (".yaml", ".yml"):
            with save_path.open("w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
    """Collect the values a document provides for a prompt's input variables.

    The candidate values are the document's ``page_content`` (under the key
    ``"page_content"``) followed by every entry of its ``metadata``.

    Args:
        doc: Document supplying ``page_content`` and ``metadata``.
        prompt: Prompt template whose ``input_variables`` must all be covered.

    Returns:
        A dict with exactly one entry per name in ``prompt.input_variables``.

    Raises:
        ValueError: If the document does not supply every input variable.
    """
    available = {"page_content": doc.page_content}
    available.update(doc.metadata)

    absent = set(prompt.input_variables).difference(available)
    if not absent:
        return {name: available[name] for name in prompt.input_variables}

    # Everything except the page content has to come from the metadata.
    required_metadata = [
        name for name in prompt.input_variables if name != "page_content"
    ]
    raise ValueError(
        f"Document prompt requires documents to have metadata variables: "
        f"{required_metadata}. Received document with missing metadata: "
        f"{list(absent)}."
    )
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
    """Format a document into a string based on a prompt template.

    First, this pulls information from the document from two sources:

    1. `page_content`:
        This takes the information from `document.page_content`
        and assigns it to a variable named `page_content`.
    2. metadata:
        This takes information from `document.metadata` and assigns
        it to variables of the same name.

    Those variables are then passed into the `prompt` to produce a formatted
    string.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        The document formatted as a string.

    Example:
        .. code-block:: python

            from langchain_core.documents import Document
            from langchain_core.prompts import PromptTemplate

            doc = Document(page_content="This is a joke", metadata={"page": "1"})
            prompt = PromptTemplate.from_template("Page {page}: {page_content}")
            format_document(doc, prompt)
            >>> "Page 1: This is a joke"
    """
    template_inputs = _get_document_info(doc, prompt)
    return prompt.format(**template_inputs)
async def aformat_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
    """Asynchronously format a document into a string based on a prompt template.

    First, this pulls information from the document from two sources:

    1. `page_content`:
        This takes the information from `document.page_content`
        and assigns it to a variable named `page_content`.
    2. metadata:
        This takes information from `document.metadata` and assigns
        it to variables of the same name.

    Those variables are then passed into the `prompt` to produce a formatted
    string.

    Args:
        doc: Document, the page_content and metadata will be used to create
            the final string.
        prompt: BasePromptTemplate, will be used to format the page_content
            and metadata into the final string.

    Returns:
        The document formatted as a string.
    """
    template_inputs = _get_document_info(doc, prompt)
    return await prompt.aformat(**template_inputs)