from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
import yaml
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model
if TYPE_CHECKING:
from langchain_core.documents import Document
FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
    RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
    input_types: Dict[str, Any] = Field(default_factory=dict)
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Any] = Field(default_factory=dict)
    """A dictionary of the partial variables the prompt template carries.
    Partial variables pre-populate the template so that they do not need to
    be passed in every time the prompt is invoked."""
    metadata: Optional[Dict[str, Any]] = None
    """Metadata to be used for tracing."""
    tags: Optional[List[str]] = None
    """Tags to be used for tracing."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "prompt_template"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def OutputType(self) -> Any:
        """The type of prompt value this template produces."""
        return Union[StringPromptValue, ChatPromptValueConcrete]

    def _validate_input(self, inner_input: Dict) -> Dict:
        """Validate (and, where unambiguous, coerce) the raw invoke input.

        A non-mapping input is wrapped as ``{var: value}`` when exactly one
        input variable is expected; otherwise a ``TypeError`` is raised.

        Raises:
            TypeError: if a non-mapping input cannot be coerced.
            KeyError: if any expected input variable is missing.
        """
        if not isinstance(inner_input, dict):
            if len(self.input_variables) == 1:
                var_name = self.input_variables[0]
                inner_input = {var_name: inner_input}
            else:
                raise TypeError(
                    f"Expected mapping type as input to {self.__class__.__name__}. "
                    f"Received {type(inner_input)}."
                )
        missing = set(self.input_variables).difference(inner_input)
        if missing:
            raise KeyError(
                f"Input to {self.__class__.__name__} is missing variables {missing}. "
                f" Expected: {self.input_variables}"
                f" Received: {list(inner_input.keys())}"
            )
        return inner_input

    def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue:
        # Validate first so format_prompt only ever sees a complete mapping.
        _inner_input = self._validate_input(inner_input)
        return self.format_prompt(**_inner_input)

    async def _aformat_prompt_with_error_handling(
        self, inner_input: Dict
    ) -> PromptValue:
        # Async counterpart of _format_prompt_with_error_handling.
        _inner_input = self._validate_input(inner_input)
        return await self.aformat_prompt(**_inner_input)

    def invoke(
        self, input: Dict, config: Optional[RunnableConfig] = None
    ) -> PromptValue:
        """Format the prompt with the input, traced as a "prompt" run."""
        config = ensure_config(config)
        if self.metadata:
            # Merge into a fresh dict — never mutate the caller's config.
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return self._call_with_config(
            self._format_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
        )

    async def ainvoke(
        self, input: Dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Async-format the prompt with the input; mirrors ``invoke``."""
        config = ensure_config(config)
        if self.metadata:
            # BUGFIX: build fresh objects instead of update()/extend()-ing in
            # place, so the caller's config dict is never mutated (this now
            # matches the behaviour of ``invoke``).
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return await self._acall_with_config(
            self._aformat_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
        )

    @root_validator()
    def validate_variable_names(cls, values: Dict) -> Dict:
        """Validate that variable names do not include restricted names."""
        if "stop" in values["input_variables"]:
            raise ValueError(
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
        if "stop" in values["partial_variables"]:
            raise ValueError(
                "Cannot have a partial variable named 'stop', as it is used "
                "internally, please rename."
            )
        overall = set(values["input_variables"]).intersection(
            values["partial_variables"]
        )
        if overall:
            raise ValueError(
                f"Found overlapping input and partial variables: {overall}"
            )
        return values

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template with *kwargs* pre-filled."""
        prompt_dict = self.__dict__.copy()
        # Variables supplied here are no longer required at format time.
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
        # Resolve partials (zero-arg callables are called now) and let
        # user-supplied kwargs override them.
        partial_kwargs = {
            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key (subclasses that support saving set it)."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary representation of the prompt."""
        prompt_dict = super().dict(**kwargs)
        try:
            prompt_dict["_type"] = self._prompt_type
        except NotImplementedError:
            # Subclasses without a _prompt_type simply omit the key.
            pass
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt to a ``.json`` or ``.yaml``/``.yml`` file.

        Args:
            file_path: Path of the file to save the prompt to.

        Raises:
            ValueError: if the template has partial variables or the file
                suffix is not a supported format.
            NotImplementedError: if this prompt type does not support saving.

        Example:
        .. code-block:: python

            prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            raise ValueError("Cannot save prompt with partial variables.")

        # Fetch dictionary to save.
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            raise NotImplementedError(f"Prompt {self} does not support saving.")

        # Normalize to a Path and make sure the parent directory exists.
        save_path = Path(file_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)

        if save_path.suffix == ".json":
            with open(save_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix in (".yaml", ".yml"):
            with open(save_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
base_info = {"page_content": doc.page_content, **doc.metadata}
missing_metadata = set(prompt.input_variables).difference(base_info)
if len(missing_metadata) > 0:
required_metadata = [
iv for iv in prompt.input_variables if iv != "page_content"
]
raise ValueError(
f"Document prompt requires documents to have metadata variables: "
f"{required_metadata}. Received document with missing metadata: "
f"{list(missing_metadata)}."
)
return {k: base_info[k] for k in prompt.input_variables}