import json
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, cast

# NOTE: import paths below assume llama_index's post-0.10 (`llama_index.core`)
# module layout.
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.program.llm_prompt_program import BaseLLMFunctionProgram
from llama_index.core.prompts.base import PromptTemplate
from llama_index.core.prompts.lmformatenforcer_utils import (
    activate_lm_format_enforcer,
    build_lm_format_enforcer_function,
)

if TYPE_CHECKING:
    from llama_index.llms.huggingface import HuggingFaceLLM
    from llama_index.llms.llama_cpp import LlamaCPP


class LMFormatEnforcerPydanticProgram(BaseLLMFunctionProgram):
"""一个基于lm-format-enforcer的函数,返回一个pydantic模型。
在LMFormatEnforcerPydanticProgram中,prompt_template_str也可以有一个{json_schema}参数,该参数将自动由output_cls的json_schema填充。
注意:此接口尚不稳定。"""
def __init__(
self,
output_cls: Type[BaseModel],
prompt_template_str: str,
        llm: Optional[Union["LlamaCPP", "HuggingFaceLLM"]] = None,
verbose: bool = False,
):
try:
import lmformatenforcer
except ImportError as e:
            raise ImportError(
                "lm-format-enforcer package not found. "
                "Please run `pip install lm-format-enforcer`."
            ) from e
if llm is None:
try:
                from llama_index.llms.llama_cpp import LlamaCPP

                llm = LlamaCPP()
            except ImportError as e:
                raise ImportError(
                    "llama-cpp-python package not found. "
                    "Please run `pip install llama-cpp-python`."
                ) from e
self.llm = llm
self._prompt_template_str = prompt_template_str
self._output_cls = output_cls
self._verbose = verbose
        # Build a token enforcer that constrains generation to the JSON schema
        # of the output class.
        json_schema_parser = lmformatenforcer.JsonSchemaParser(self.output_cls.schema())
        self._token_enforcer_fn = build_lm_format_enforcer_function(
            self.llm, json_schema_parser
        )

@classmethod
def from_defaults(
cls,
output_cls: Type[BaseModel],
prompt_template_str: Optional[str] = None,
prompt: Optional[PromptTemplate] = None,
llm: Optional[Union["LlamaCPP", "HuggingFaceLLM"]] = None,
**kwargs: Any,
) -> "BaseLLMFunctionProgram":
"""从默认值。"""
if prompt is None and prompt_template_str is None:
raise ValueError("Must provide either prompt or prompt_template_str.")
        if prompt is not None and prompt_template_str is not None:
            raise ValueError("Must provide only one of prompt or prompt_template_str.")
if prompt is not None:
prompt_template_str = prompt.template
prompt_template_str = cast(str, prompt_template_str)
return cls(
output_cls,
prompt_template_str,
llm=llm,
**kwargs,
)
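
    # Illustrative use of from_defaults above (`MyModel` and the template are
    # hypothetical): exactly one of `prompt_template_str` / `prompt` may be
    # given, e.g.
    #
    #     program = LMFormatEnforcerPydanticProgram.from_defaults(
    #         output_cls=MyModel,
    #         prompt_template_str="{query}\nAnswer with JSON: {json_schema}",
    #     )
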
@property
def output_cls(self) -> Type[BaseModel]:
return self._output_cls

    def __call__(
self,
llm_kwargs: Optional[Dict[str, Any]] = None,
*args: Any,
**kwargs: Any,
) -> BaseModel:
llm_kwargs = llm_kwargs or {}
# While the format enforcer is active, any calls to the llm will have the format enforced.
with activate_lm_format_enforcer(self.llm, self._token_enforcer_fn):
json_schema_str = json.dumps(self.output_cls.schema())
full_str = self._prompt_template_str.format(
*args, **kwargs, json_schema=json_schema_str
)
output = self.llm.complete(full_str, **llm_kwargs)
text = output.text
return self.output_cls.parse_raw(text)
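

# Usage sketch (illustrative, not part of the class above): `Song`, the prompt,
# and the topic are made-up names; running this requires `lm-format-enforcer`
# plus `llama-cpp-python` with a local model for the default LlamaCPP to load.
if __name__ == "__main__":

    class Song(BaseModel):
        title: str
        length_seconds: int

    program = LMFormatEnforcerPydanticProgram(
        output_cls=Song,
        prompt_template_str=(
            "Write a song about {topic}. "
            "Respond only with JSON matching this schema: {json_schema}"
        ),
        verbose=True,
    )
    song = program(topic="the sea")  # returns a validated Song instance
    print(song.title, song.length_seconds)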