"""实现了程序辅助语言模型。
该模块实现了用于生成代码解决方案的程序辅助语言模型(PAL)。PAL是一种在论文“程序辅助语言模型”(https://arxiv.org/pdf/2211.10435.pdf)中描述的技术。
"""
from __future__ import annotations
import ast
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain_community.utilities import PythonREPL
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.pal_chain.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain_experimental.pal_chain.math_prompt import MATH_PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field
# Callable names that PALChain.validate_code rejects when they appear as the
# target of a call (either as a bare name or as an attribute, e.g. `os.system`)
# and command execution is not allowed.
COMMAND_EXECUTION_FUNCTIONS = [
    "system",
    "exec",
    "execfile",
    "eval",
    "__import__",
]

# Attribute names that PALChain.validate_code rejects on any attribute access
# when command execution is not allowed.
COMMAND_EXECUTION_ATTRIBUTES = [
    "__import__",
    "__subclasses__",
    "__builtins__",
    "__globals__",
    "__getattribute__",
    "__bases__",
    "__mro__",
    "__base__",
]
class PALValidation:
    """Validation settings applied to PAL-generated code.

    An instance records which solution expression (if any) the generated code
    is expected to define, and whether import statements and known
    command-execution functions are permitted.
    """

    # AST node classes accepted as ``solution_expression_type``.
    SOLUTION_EXPRESSION_TYPE_FUNCTION = ast.FunctionDef
    SOLUTION_EXPRESSION_TYPE_VARIABLE = ast.Name

    def __init__(
        self,
        solution_expression_name: Optional[str] = None,
        solution_expression_type: Optional[type] = None,
        allow_imports: bool = False,
        allow_command_exec: bool = False,
    ):
        """Initialize a PALValidation instance.

        Args:
            solution_expression_name: Name of the expected solution expression.
                If passed, ``solution_expression_type`` must be passed as well.
            solution_expression_type: AST node class of the expected solution
                expression. If passed, ``solution_expression_name`` must be
                passed as well. Must be one of
                ``PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION`` or
                ``PALValidation.SOLUTION_EXPRESSION_TYPE_VARIABLE``.
            allow_imports: Allow import statements.
            allow_command_exec: Allow using known command execution functions.

        Raises:
            ValueError: If ``solution_expression_name`` is not a ``str``, or if
                ``solution_expression_type`` is not one of the two supported
                AST node classes.
            TypeError: If only one of ``solution_expression_name`` and
                ``solution_expression_type`` is passed.
        """
        self.solution_expression_name = solution_expression_name
        self.solution_expression_type = solution_expression_type

        # Per-argument checks run first, so a malformed argument is reported
        # as ValueError even when its companion argument is also missing.
        if solution_expression_name is not None and not isinstance(
            solution_expression_name, str
        ):
            raise ValueError(
                f"Expected solution_expression_name to be str, "
                f"instead found {type(solution_expression_name)}"
            )

        allowed_types = (
            self.SOLUTION_EXPRESSION_TYPE_FUNCTION,
            self.SOLUTION_EXPRESSION_TYPE_VARIABLE,
        )
        # Identity comparison on purpose: only these exact classes are valid.
        if solution_expression_type is not None and not any(
            solution_expression_type is allowed for allowed in allowed_types
        ):
            raise ValueError(
                f"Expected solution_expression_type to be one of "
                f"({self.SOLUTION_EXPRESSION_TYPE_FUNCTION}, "
                f"{self.SOLUTION_EXPRESSION_TYPE_VARIABLE}), "
                f"instead found {solution_expression_type}"
            )

        # Name and type only make sense as a pair.
        if solution_expression_name is not None and solution_expression_type is None:
            raise TypeError(
                "solution_expression_name "
                "requires solution_expression_type to be passed as well"
            )
        if solution_expression_name is None and solution_expression_type is not None:
            raise TypeError(
                "solution_expression_type "
                "requires solution_expression_name to be passed as well"
            )

        self.allow_imports = allow_imports
        self.allow_command_exec = allow_command_exec
class PALChain(Chain):
    """Chain that implements Program-Aided Language Models (PAL).

    This class implements the Program-Aided Language Models (PAL) technique for
    generating code solutions, described in the paper "Program-Aided Language
    Models" (https://arxiv.org/pdf/2211.10435.pdf).

    *Security note*: This class implements an AI technique that generates and
    evaluates Python code, which can be dangerous and must be used in a safe
    environment. While this class implements some basic guardrails by limiting
    the available locals/globals and by parsing and inspecting the generated
    Python AST using `PALValidation`, those guardrails will not deter
    sophisticated attackers and are not a replacement for a proper sandbox.
    Do not use this class on untrusted inputs, with elevated permissions, or
    without consulting your security team!
    """

    llm_chain: LLMChain
    stop: str = "\n\n"
    """Stop token to use when generating code."""
    get_answer_expr: str = "print(solution())"
    """Expression appended to the generated code to obtain the answer."""
    python_globals: Optional[Dict[str, Any]] = None
    """Globals passed to the Python REPL that executes the generated code."""
    python_locals: Optional[Dict[str, Any]] = None
    """Locals passed to the Python REPL that executes the generated code."""
    output_key: str = "result"  #: :meta private:
    return_intermediate_steps: bool = False
    """Whether to also return the generated code as an intermediate step."""
    code_validations: PALValidation = Field(default_factory=PALValidation)
    """Validations to perform on the generated code."""
    timeout: Optional[int] = 10
    """Timeout in seconds for the execution of the generated code."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the input variables of the LLM chain's prompt.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return the output key, plus ``intermediate_steps`` when enabled.

        :meta private:
        """
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate code with the LLM, validate it, execute it, return the result.

        Args:
            inputs: Values for the prompt's input variables.
            run_manager: Optional callback manager for this run; a no-op
                manager is used when omitted.

        Returns:
            A dict mapping ``output_key`` to the stripped REPL output and, when
            ``return_intermediate_steps`` is set, ``"intermediate_steps"`` to
            the generated code.

        Raises:
            ValueError: If the generated code fails ``validate_code``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        code = self.llm_chain.predict(
            stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
        )
        _run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)
        # Validation happens before anything is handed to the REPL.
        PALChain.validate_code(code, self.code_validations)

        # TODO: look into why mypy thinks PythonREPL's type here is `Any`
        #       and therefore not callable
        repl = PythonREPL(
            _globals=self.python_globals,
            _locals=self.python_locals,
        )  # type: ignore[misc]
        res = repl.run(code + f"\n{self.get_answer_expr}", timeout=self.timeout)
        output = {self.output_key: res.strip()}
        if self.return_intermediate_steps:
            output["intermediate_steps"] = code
        return output

    @classmethod
    def validate_code(cls, code: str, code_validations: PALValidation) -> None:
        """Check generated code against ``code_validations``.

        The checks run in this order:

        1. The code must parse with ``ast.parse``.
        2. If a solution expression name and type are configured, a top-level
           node of that type with that name (or a top-level assignment to a
           target of that type with that name) must exist.
        3. If imports are not allowed, no top-level import may be present.
        4. The whole tree is walked; when command execution is not allowed,
           attribute accesses listed in ``COMMAND_EXECUTION_ATTRIBUTES`` and
           calls to names listed in ``COMMAND_EXECUTION_FUNCTIONS`` are
           rejected, and when imports are not allowed, import statements at
           any depth are rejected.

        Args:
            code: The generated source code.
            code_validations: The validation settings to enforce.

        Raises:
            ValueError: If any of the checks above fails.
        """
        try:
            code_tree = ast.parse(code)
        except (SyntaxError, UnicodeDecodeError) as err:
            raise ValueError(
                f"Generated code is not valid python code: {code}"
            ) from err
        except TypeError as err:
            raise ValueError(
                f"Generated code is expected to be a string, "
                f"instead found {type(code)}"
            ) from err
        except OverflowError as err:
            raise ValueError(
                f"Generated code too long / complex to be parsed by ast: {code}"
            ) from err

        expected_name = code_validations.solution_expression_name
        expected_type = code_validations.solution_expression_type
        check_solution = expected_name is not None and expected_type is not None

        # Without an expected name there is nothing to look for, so the
        # solution-expression check is treated as already satisfied.
        found_solution_expr = expected_name is None
        has_imports = False

        for node in ast.iter_child_nodes(code_tree):
            if check_solution:
                # Root-level definitions (e.g. a function definition).
                if (
                    isinstance(node, expected_type)
                    and getattr(node, "name", None) == expected_name
                ):
                    found_solution_expr = True
                # Root-level assignments (e.g. an answer variable).
                if isinstance(node, ast.Assign):
                    for target_node in node.targets:
                        if (
                            isinstance(target_node, expected_type)
                            and getattr(target_node, "id", None) == expected_name
                        ):
                            found_solution_expr = True
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                has_imports = True

        if not found_solution_expr:
            raise ValueError(
                f"Generated code is missing the solution expression: "
                f"{code_validations.solution_expression_name} of type: "
                f"{code_validations.solution_expression_type}"
            )

        forbid_imports = not code_validations.allow_imports
        forbid_command_exec = not code_validations.allow_command_exec

        if forbid_imports and has_imports:
            raise ValueError(f"Generated code has disallowed imports: {code}")

        if not (forbid_command_exec or forbid_imports):
            # Nothing is restricted, so the deep walk is unnecessary.
            return

        for node in ast.walk(code_tree):
            if (
                forbid_command_exec
                and isinstance(node, ast.Attribute)
                and node.attr in COMMAND_EXECUTION_ATTRIBUTES
            ):
                raise ValueError(
                    f"Found illegal command execution function "
                    f"{node.attr} in code {code}"
                )
            if forbid_command_exec and isinstance(node, ast.Call):
                func = node.func
                # Direct call by bare name, e.g. `eval(...)`.
                if isinstance(func, ast.Name) and func.id in COMMAND_EXECUTION_FUNCTIONS:
                    raise ValueError(
                        f"Found illegal command execution function "
                        f"{func.id} in code {code}"
                    )
                # Call through an attribute, e.g. `os.system(...)`.
                if (
                    isinstance(func, ast.Attribute)
                    and func.attr in COMMAND_EXECUTION_FUNCTIONS
                ):
                    raise ValueError(
                        f"Found illegal command execution function "
                        f"{func.attr} in code {code}"
                    )
            # Catches imports nested inside functions, classes, etc.
            if forbid_imports and isinstance(node, (ast.Import, ast.ImportFrom)):
                raise ValueError(f"Generated code has disallowed imports: {code}")

    @classmethod
    def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
        """Load PAL from the math prompt.

        The generated code is required to define a function named
        ``solution``; the answer is obtained with ``print(solution())``.

        Args:
            llm: The language model to use for generating code.
            **kwargs: Additional keyword arguments forwarded to the class
                constructor.

        Returns:
            PALChain: An instance of PALChain.
        """
        llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)
        code_validations = PALValidation(
            solution_expression_name="solution",
            solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_FUNCTION,
        )

        return cls(
            llm_chain=llm_chain,
            stop="\n\n",
            get_answer_expr="print(solution())",
            code_validations=code_validations,
            **kwargs,
        )

    @classmethod
    def from_colored_object_prompt(
        cls, llm: BaseLanguageModel, **kwargs: Any
    ) -> PALChain:
        """Load PAL from the colored object prompt.

        The generated code is required to assign a variable named ``answer``;
        the answer is obtained with ``print(answer)``.

        Args:
            llm: The language model to use for generating code.
            **kwargs: Additional keyword arguments forwarded to the class
                constructor.

        Returns:
            PALChain: An instance of PALChain.
        """
        llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT)
        code_validations = PALValidation(
            solution_expression_name="answer",
            solution_expression_type=PALValidation.SOLUTION_EXPRESSION_TYPE_VARIABLE,
        )
        return cls(
            llm_chain=llm_chain,
            stop="\n\n\n",
            get_answer_expr="print(answer)",
            code_validations=code_validations,
            **kwargs,
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string of this chain type."""
        return "pal_chain"