Source code for langchain_experimental.tools.python.tool

"""A tool for running Python code in a REPL."""

import ast
import re
import sys
from contextlib import redirect_stdout
from io import StringIO
from typing import Any, Dict, Optional, Type

from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool
from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.runnables.config import run_in_executor

from langchain_experimental.utilities.python import PythonREPL


def _get_default_python_repl() -> PythonREPL:
    """Build the fallback REPL used when no ``python_repl`` is supplied.

    The REPL is handed this module's global namespace (``globals()``) and
    ``None`` for ``_locals``.
    """
    # NOTE(review): if PythonREPL executes code directly in ``_globals``,
    # user code could rebind names of this module — confirm in PythonREPL.
    module_namespace = globals()
    return PythonREPL(_locals=None, _globals=module_namespace)


def sanitize_input(query: str) -> str:
    """Sanitize input to the Python REPL.

    Removes leading/trailing whitespace and backticks (e.g. a Markdown code
    fence) and a leading ``python``/``python3`` keyword. An LLM may emit that
    keyword when it mistakes the Python console for a terminal or opens a
    fenced block with ```python.

    The keyword is only removed when it is a whole word. Code that starts
    with an identifier such as ``python_version = 3`` or ``pythonic = True``
    is therefore left intact.

    Args:
        query: The query to sanitize.

    Returns:
        str: The sanitized query.
    """
    # Leading whitespace/backticks, then an optional standalone "python" or
    # "python3" (any case). The \b keeps us from eating the start of an
    # identifier like ``python_version``.
    query = re.sub(r"^(\s|`)*(?i:python3?\b)?\s*", "", query)
    # Trailing whitespace and backticks (e.g. a closing code fence).
    query = re.sub(r"(\s|`)*$", "", query)
    return query
class PythonREPLTool(BaseTool):
    """Tool for running Python code in a REPL.

    Execution is delegated to a ``PythonREPL`` instance. Unless one is
    supplied, it is built by ``_get_default_python_repl``.
    """

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "If you want to see the output of a value, you should print it out "
        "with `print(...)`."
    )
    # REPL that actually executes the submitted code.
    python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
    # When True, the query is passed through ``sanitize_input`` before running.
    sanitize_input: bool = True

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool.

        Args:
            query: Python source to execute.
            run_manager: Callback manager for this run (not used here).

        Returns:
            Whatever ``self.python_repl.run`` returns for the (optionally
            sanitized) query.
        """
        if self.sanitize_input:
            query = sanitize_input(query)
        return self.python_repl.run(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously by running ``_run`` in an executor.

        Calls the private ``_run`` rather than the public ``run``. Going
        through ``run`` would start a second, detached tool run (firing tool
        callbacks again), and the query would be sanitized twice. ``_run``
        already applies sanitization.
        """
        return await run_in_executor(None, self._run, query)
class PythonInputs(BaseModel):
    """Argument schema holding the single Python code snippet to execute."""

    # ``...`` marks the field as required explicitly.
    query: str = Field(..., description="code snippet to run")
class PythonAstREPLTool(BaseTool):
    """Tool for running Python code in a REPL.

    The query is split into its last statement and everything before it.
    If the last statement is an expression, its value is returned (like an
    interactive interpreter). Otherwise, the text printed to stdout is
    returned.
    """

    name: str = "python_repl_ast"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "When using this tool, sometimes output is abbreviated - "
        "make sure it does not look abbreviated before using it in your answer."
    )
    # Namespaces passed to exec/eval. The same dicts are reused on every
    # call, so names defined by one query are visible to later ones.
    globals: Optional[Dict] = Field(default_factory=dict)
    locals: Optional[Dict] = Field(default_factory=dict)
    # When True, the query is passed through ``sanitize_input`` first.
    sanitize_input: bool = True
    args_schema: Type[BaseModel] = PythonInputs

    @root_validator(pre=True)
    def validate_python_version(cls, values: Dict) -> Dict:
        """Reject interpreters older than 3.9 (``ast.unparse`` is required)."""
        if sys.version_info < (3, 9):
            raise ValueError(
                "This tool relies on Python 3.9 or higher "
                "(as it uses new functionality in the `ast` module), "
                f"you have Python version: {sys.version}"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool.

        Every statement but the last is run with ``exec``. If the last
        statement is an expression, it is evaluated with ``eval``; otherwise
        it is also run with ``exec``. Stdout produced by *all* statements is
        captured.

        Args:
            query: Python source to execute.
            run_manager: Callback manager for this run (not used here).

        Returns:
            The trailing expression's value if it is not ``None``; otherwise
            the captured stdout. Any exception is reported as the string
            ``"<ExceptionType>: <message>"`` instead of being raised.
        """
        # NOTE: this executes arbitrary code in-process with no sandboxing.
        try:
            if self.sanitize_input:
                query = sanitize_input(query)
            tree = ast.parse(query)
            module = ast.Module(tree.body[:-1], type_ignores=[])
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = ast.unparse(module_end)  # type: ignore
            # Only an expression statement can be eval'd. Deciding up front,
            # instead of trying eval and falling back to exec on *any*
            # exception, avoids running the last statement twice when it
            # raises at runtime.
            end_is_expression = bool(tree.body) and isinstance(
                tree.body[-1], ast.Expr
            )
            io_buffer = StringIO()
            # Capture output from every statement, not just the last one.
            with redirect_stdout(io_buffer):
                exec(ast.unparse(module), self.globals, self.locals)  # type: ignore
                if end_is_expression:
                    ret = eval(module_end_str, self.globals, self.locals)
                else:
                    exec(module_end_str, self.globals, self.locals)
                    ret = None
            if ret is None:
                return io_buffer.getvalue()
            return ret
        except Exception as e:
            return "{}: {}".format(type(e).__name__, str(e))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously by running ``_run`` in an executor."""
        return await run_in_executor(None, self._run, query)