from __future__ import annotations
import contextlib
import json
import typing
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
)
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self, override
from langchain_core.exceptions import ErrorCode, create_message
from langchain_core.load import dumpd
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
PromptValue,
StringPromptValue,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.utils.pydantic import create_model_v2
if TYPE_CHECKING:
    # Imported for static type checking only; the name is used solely in
    # annotations, which are lazy strings thanks to
    # ``from __future__ import annotations``.
    from langchain_core.documents import Document

# Type variable parameterising ``BasePromptTemplate`` via ``Generic``.
FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
    RunnableSerializable[dict, PromptValue], Generic[FormatOutputType], ABC
):
    """Base class for all prompt templates, returning a prompt.

    The ``invoke`` / ``ainvoke`` paths validate the incoming mapping and then
    call ``self.format_prompt`` / ``self.aformat_prompt``. Neither is defined
    in this class body, so they are presumably supplied elsewhere (e.g. by
    subclasses) -- NOTE(review): confirm against the rest of the package.
    """

    input_variables: list[str]
    """A list of the names of the variables whose values are required as inputs to the
    prompt."""
    optional_variables: list[str] = Field(default=[])
    """A list of the names of the variables for placeholder or MessagePlaceholder
    that are optional. These variables are auto inferred from the prompt and
    user need not provide them."""
    input_types: typing.Dict[str, Any] = Field(default_factory=dict, exclude=True)  # noqa: UP006
    """A dictionary of the types of the variables the prompt template expects.
    If not provided, all variables are assumed to be strings."""
    output_parser: Optional[BaseOutputParser] = None
    """How to parse the output of calling an LLM on this formatted prompt."""
    partial_variables: Mapping[str, Any] = Field(default_factory=dict)
    """A dictionary of the partial variables the prompt template carries.

    Partial variables populate the template so that you don't need to
    pass them in every time you call the prompt."""
    metadata: Optional[typing.Dict[str, Any]] = None  # noqa: UP006
    """Metadata to be used for tracing."""
    tags: Optional[list[str]] = None
    """Tags to be used for tracing."""

    @model_validator(mode="after")
    def validate_variable_names(self) -> Self:
        """Validate variable names do not include restricted names.

        Returns:
            The validated instance itself.

        Raises:
            ValueError: If ``"stop"`` is used as an input or partial variable
                name, or if a name appears in both ``input_variables`` and
                ``partial_variables``.
        """
        if "stop" in self.input_variables:
            msg = (
                "Cannot have an input variable named 'stop', as it is used internally,"
                " please rename."
            )
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        if "stop" in self.partial_variables:
            msg = (
                "Cannot have an partial variable named 'stop', as it is used "
                "internally, please rename."
            )
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )

        # A name may be either a required input or a partial, never both.
        overall = set(self.input_variables).intersection(self.partial_variables)
        if overall:
            msg = f"Found overlapping input and partial variables: {overall}"
            raise ValueError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        return self

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Get the namespace of the langchain object.

        Returns:
            ``["langchain", "schema", "prompt_template"]``.
        """
        return ["langchain", "schema", "prompt_template"]

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable.

        Returns:
            Always ``True``.
        """
        return True

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @cached_property
    def _serialized(self) -> dict[str, Any]:
        """Result of ``dumpd(self)``, computed once per instance and cached."""
        return dumpd(self)

    @property
    @override
    def OutputType(self) -> Any:
        """Return the output type of the prompt."""
        return Union[StringPromptValue, ChatPromptValueConcrete]

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> type[BaseModel]:
        """Get the input schema for the prompt.

        Every name in ``input_variables`` becomes a required field and every
        name in ``optional_variables`` becomes a field defaulting to ``None``.
        Field types come from ``input_types``, falling back to ``str``.

        Args:
            config: RunnableConfig, configuration for the prompt. Not used by
                this implementation.

        Returns:
            Type[BaseModel]: The input schema for the prompt.
        """
        # This is correct, but pydantic typings/mypy don't think so.
        required_input_variables = {
            k: (self.input_types.get(k, str), ...) for k in self.input_variables
        }
        optional_input_variables = {
            k: (self.input_types.get(k, str), None) for k in self.optional_variables
        }
        return create_model_v2(
            "PromptInput",
            field_definitions={**required_input_variables, **optional_input_variables},
        )

    def _validate_input(self, inner_input: Any) -> dict:
        """Coerce and validate the raw input handed to the prompt.

        A non-dict input is accepted only when the prompt has exactly one
        input variable; it is then wrapped as ``{that_variable: inner_input}``.

        Args:
            inner_input: The raw input, normally a dict keyed by variable name.

        Returns:
            The input as a dict containing at least every required variable.

        Raises:
            TypeError: If the input is not a dict and the prompt does not have
                exactly one input variable.
            KeyError: If any name in ``input_variables`` is missing.
        """
        if not isinstance(inner_input, dict):
            if len(self.input_variables) == 1:
                var_name = self.input_variables[0]
                inner_input = {var_name: inner_input}
            else:
                msg = (
                    f"Expected mapping type as input to {self.__class__.__name__}. "
                    f"Received {type(inner_input)}."
                )
                raise TypeError(
                    create_message(
                        message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT
                    )
                )
        missing = set(self.input_variables).difference(inner_input)
        if missing:
            msg = (
                f"Input to {self.__class__.__name__} is missing variables {missing}. "
                f" Expected: {self.input_variables}"
                f" Received: {list(inner_input.keys())}"
            )
            # Use one of the missing names to illustrate brace escaping; the
            # full set has already been rendered into ``msg`` above.
            example_key = missing.pop()
            msg += (
                f"\nNote: if you intended {{{example_key}}} to be part of the string"
                " and not a variable, please escape it with double curly braces like: "
                f"'{{{{{example_key}}}}}'."
            )
            raise KeyError(
                create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
            )
        return inner_input

    def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
        """Validate ``inner_input`` and pass it to ``format_prompt`` as kwargs."""
        _inner_input = self._validate_input(inner_input)
        return self.format_prompt(**_inner_input)

    async def _aformat_prompt_with_error_handling(
        self, inner_input: dict
    ) -> PromptValue:
        """Validate ``inner_input`` and await ``aformat_prompt`` with it as kwargs."""
        _inner_input = self._validate_input(inner_input)
        return await self.aformat_prompt(**_inner_input)

    def invoke(
        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Invoke the prompt.

        The prompt's own ``metadata`` and ``tags`` (when set) are merged into
        the config before the call is dispatched through
        ``_call_with_config`` with ``run_type="prompt"``.

        Args:
            input: Dict, input to the prompt.
            config: RunnableConfig, configuration for the prompt.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            PromptValue: The output of the prompt.
        """
        config = ensure_config(config)
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return self._call_with_config(
            self._format_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
            serialized=self._serialized,
        )

    async def ainvoke(
        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> PromptValue:
        """Async invoke the prompt.

        Mirrors ``invoke``: the prompt's own ``metadata`` and ``tags`` (when
        set) are merged into the config before dispatching through
        ``_acall_with_config`` with ``run_type="prompt"``.

        Args:
            input: Dict, input to the prompt.
            config: RunnableConfig, configuration for the prompt.
            kwargs: Accepted for interface compatibility; not used here.

        Returns:
            PromptValue: The output of the prompt.
        """
        config = ensure_config(config)
        # Build new containers, exactly as ``invoke`` does, instead of mutating
        # whatever objects the config currently holds.
        if self.metadata:
            config["metadata"] = {**config["metadata"], **self.metadata}
        if self.tags:
            config["tags"] = config["tags"] + self.tags
        return await self._acall_with_config(
            self._aformat_prompt_with_error_handling,
            input,
            config,
            run_type="prompt",
            serialized=self._serialized,
        )

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
        """Return a partial of the prompt template.

        The names given in ``kwargs`` are removed from ``input_variables`` and
        added to ``partial_variables`` (overriding existing entries with the
        same name) on a new instance of the same class.

        Args:
            kwargs: Union[str, Callable[[], str]], partial variables to set.

        Returns:
            BasePromptTemplate: A partial of the prompt template.
        """
        prompt_dict = self.__dict__.copy()
        # Keep the remaining variables in their declared order (and drop
        # duplicates) rather than round-tripping through a set, whose
        # iteration order is arbitrary.
        prompt_dict["input_variables"] = list(
            dict.fromkeys(
                name for name in self.input_variables if name not in kwargs
            )
        )
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def _merge_partial_and_user_variables(self, **kwargs: Any) -> dict[str, Any]:
        """Combine partial variables with user-supplied ones.

        Callable partial values are invoked (with no arguments) to obtain
        their value. User-supplied ``kwargs`` take precedence on conflicts.
        """
        # Get partial params:
        partial_kwargs = {
            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> dict:
        """Return dictionary representation of prompt.

        The result of ``model_dump`` is extended with a ``"_type"`` entry
        holding ``self._prompt_type``. If ``_prompt_type`` raises
        ``NotImplementedError`` the error is suppressed and the key is
        simply omitted.

        Args:
            kwargs: Any additional arguments to pass to the dictionary.

        Returns:
            Dict: Dictionary representation of the prompt.
        """
        prompt_dict = super().model_dump(**kwargs)
        with contextlib.suppress(NotImplementedError):
            prompt_dict["_type"] = self._prompt_type
        return prompt_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the prompt.

        Args:
            file_path: Path of the file to write. The suffix selects the
                format: ``.json``, or ``.yaml`` / ``.yml``. Missing parent
                directories are created.

        Raises:
            ValueError: If the prompt has partial variables.
            ValueError: If the file path is not json or yaml.
            NotImplementedError: If the prompt type is not implemented.

        Example:
            .. code-block:: python

                prompt.save(file_path="path/prompt.yaml")
        """
        if self.partial_variables:
            msg = "Cannot save prompt with partial variables."
            raise ValueError(msg)

        # Fetch dictionary to save
        prompt_dict = self.dict()
        if "_type" not in prompt_dict:
            msg = f"Prompt {self} does not support saving."
            raise NotImplementedError(msg)

        save_path = Path(file_path)
        suffix = save_path.suffix
        # Reject unsupported formats before touching the filesystem so that a
        # failed call leaves no freshly created directories behind.
        if suffix not in (".json", ".yaml", ".yml"):
            msg = f"{save_path} must be json or yaml"
            raise ValueError(msg)

        save_path.parent.mkdir(parents=True, exist_ok=True)
        with save_path.open("w", encoding="utf-8") as f:
            if suffix == ".json":
                json.dump(prompt_dict, f, indent=4)
            else:
                yaml.dump(prompt_dict, f, default_flow_style=False)
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> dict:
    """Collect the values a document prompt needs from a document.

    The candidate values are the document's ``page_content`` plus every entry
    of its ``metadata`` (a metadata key named ``"page_content"`` overrides the
    content itself).

    Args:
        doc: Object exposing ``page_content`` and a ``metadata`` mapping.
        prompt: Prompt whose ``input_variables`` name the values to extract.

    Returns:
        A dict with exactly one entry per name in ``prompt.input_variables``.

    Raises:
        ValueError: If any input variable is found neither in the metadata
            nor as ``page_content``.
    """
    wanted = prompt.input_variables

    available: dict = {"page_content": doc.page_content}
    available.update(doc.metadata)

    absent = set(wanted).difference(available)
    if absent:
        # Everything except the page content has to come from metadata.
        metadata_names = [name for name in wanted if name != "page_content"]
        msg = (
            f"Document prompt requires documents to have metadata variables: "
            f"{metadata_names}. Received document with missing metadata: "
            f"{list(absent)}."
        )
        raise ValueError(
            create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
        )

    return {name: available[name] for name in wanted}