Source code for langchain_community.tools.e2b_data_analysis.tool

from __future__ import annotations

import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type, Union

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManager,
    CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr

from langchain_community.tools import BaseTool, Tool
from langchain_community.tools.e2b_data_analysis.unparse import Unparser

if TYPE_CHECKING:
    from e2b import EnvVars
    from e2b.templates.data_analysis import Artifact

# Tool description shown to the LLM. Adjacent string literals are joined at
# compile time, so the result is one line of text with no embedded newlines.
base_description = (
    "Evaluates python code in a sandbox environment. "
    "The environment is long running and exists across multiple executions. "
    "You must send the whole script every time and print your outputs. "
    "Script should be pure python code that can be evaluated. "
    "It should be in python format NOT markdown. "
    "The code should NOT be wrapped in backticks. "
    "All python packages including requests, matplotlib, scipy, numpy, pandas, "
    "etc are available. Create and display chart using `plt.show()`."
)


def _unparse(tree: ast.AST) -> str:
    """Unparse an AST back into Python source code.

    ``ast.unparse`` is only available on Python 3.9+, so older interpreters
    fall back to the vendored ``Unparser`` writing into a string buffer.

    Args:
        tree: The AST to convert back into source.

    Returns:
        The source code regenerated from ``tree``.
    """
    # Compare the full version tuple: checking only ``minor`` would treat a
    # future major release (e.g. 4.0) as older than 3.9.
    if version_info < (3, 9):
        with StringIO() as buffer:
            Unparser(tree, file=buffer)
            # getvalue() runs before the buffer is closed by ``with``.
            return buffer.getvalue()
    return ast.unparse(tree)  # type: ignore[attr-defined]


def add_last_line_print(code: str) -> str:
    """Add a ``print`` statement to the last line if it is missing.

    LLM-generated code sometimes ends with a bare ``variable_name`` instead of
    ``print(variable_name)``, expecting REPL-style echoing of the value. This
    parses the code and, if the final statement is a bare expression that is
    not already a ``print(...)`` call, wraps that expression in ``print(...)``.

    Args:
        code: The Python source to inspect.

    Returns:
        The (possibly modified) source, regenerated from the AST. Because the
        output is rebuilt from the AST, comments and original formatting are
        not preserved.

    Raises:
        SyntaxError: If ``code`` cannot be parsed.
    """
    tree = ast.parse(code)

    # Empty or comment-only input has no statements; indexing body[-1]
    # would raise IndexError.
    if not tree.body:
        return _unparse(tree)

    node = tree.body[-1]
    # Only bare expressions produce a value worth echoing; assignments,
    # loops, etc. are returned unchanged.
    if not isinstance(node, ast.Expr):
        return _unparse(tree)

    value = node.value
    already_printed = (
        isinstance(value, ast.Call)
        and isinstance(value.func, ast.Name)
        and value.func.id == "print"
    )
    if not already_printed:
        tree.body[-1] = ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[value],
                keywords=[],
            )
        )
    return _unparse(tree)
class UploadedFile(BaseModel):
    """Record of a file that has been uploaded into the sandbox."""

    # Base name of the local file that was uploaded.
    name: str
    # Location of the file inside the sandbox.
    remote_path: str
    # Caller-supplied description; an empty string means "no description".
    description: str
class E2BDataAnalysisToolArguments(BaseModel):
    """Input schema for the E2B data analysis tool."""

    # The script that the tool evaluates in the sandbox.
    python_code: str = Field(
        ...,
        description=(
            "The python script to be evaluated. The contents will be in "
            "main.py. It should not be in markdown format."
        ),
        example="print('Hello World')",
    )
class E2BDataAnalysisTool(BaseTool):
    """Tool for running Python code in a sandboxed environment for data analysis."""

    name = "e2b_data_analysis"
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    # Handle to the E2B ``DataAnalysis`` sandbox created in ``__init__``.
    session: Any
    description: str
    # Files uploaded through ``upload_file``; listed in ``description``.
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ):
        """Create the tool and open an E2B ``DataAnalysis`` sandbox session.

        ``api_key``, ``cwd``, ``env_vars`` and the ``on_*`` callbacks are
        forwarded unchanged to the E2B ``DataAnalysis`` constructor; remaining
        ``kwargs`` go to ``BaseTool``.

        Raises:
            ImportError: If the ``e2b`` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        super().__init__(description=base_description, **kwargs)
        # If no API key is provided, E2B will try to read it from the
        # environment variable E2B_API_KEY.
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    def close(self) -> None:
        """Close the cloud sandbox and forget all uploaded files."""
        self._uploaded_files = []
        # Keep the description consistent with the (now empty) file list.
        self._refresh_description()
        self.session.close()

    @property
    def uploaded_files_description(self) -> str:
        """Listing of uploaded files for the tool description.

        Returns an empty string when no files have been uploaded.
        """
        if not self._uploaded_files:
            return ""
        lines = ["The following files available in the sandbox:"]
        for f in self._uploaded_files:
            if f.description == "":
                lines.append(f"- path: `{f.remote_path}`")
            else:
                lines.append(
                    f"- path: `{f.remote_path}` \n description: `{f.description}`"
                )
        return "\n".join(lines)

    def _refresh_description(self) -> None:
        """Rebuild ``description`` from the base text plus the current file list."""
        # Rebuild from scratch rather than appending: appending on every
        # upload/removal repeated the whole listing each time and kept
        # advertising files that had already been removed.
        files = self.uploaded_files_description
        if files:
            self.description = base_description + "\n" + files
        else:
            self.description = base_description

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Evaluate ``python_code`` in the sandbox.

        Returns:
            A JSON string with keys ``stdout``, ``stderr`` and ``artifacts``
            (the list of artifact names).
        """
        python_code = add_last_line_print(python_code)

        on_artifact = None
        if callbacks is not None:
            # ``callbacks.metadata`` is a mapping, so the hook must be looked
            # up by key; ``getattr`` on a dict always yielded None and
            # silently dropped the caller's ``on_artifact`` callback.
            metadata = callbacks.metadata or {}
            on_artifact = metadata.get("on_artifact")

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )
        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": [artifact.name for artifact in artifacts],
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run a shell command in the sandbox and wait for it to finish.

        Returns:
            A dict with ``stdout``, ``stderr`` and ``exit_code`` taken from
            the finished process output.
        """
        proc = self.session.process.start(cmd)
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install Python packages in the sandbox."""
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install system packages (via apt) in the sandbox."""
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download a file from the sandbox."""
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload a file to the sandbox.

        The file will be uploaded to the '/home/user/<filename>' path.

        Args:
            file: Open file object; its ``name`` provides the recorded name.
            description: Text listed next to the file in the tool description.

        Returns:
            The ``UploadedFile`` record that was added.
        """
        remote_path = self.session.upload_file(file)
        f = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        self._uploaded_files.append(f)
        self._refresh_description()
        return f

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove an uploaded file from the sandbox and from the file list."""
        self.session.filesystem.remove(uploaded_file.remote_path)
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        self._refresh_description()

    def as_tool(self) -> Tool:
        """Wrap ``_run`` as a plain ``Tool``.

        The description is captured at call time, so files uploaded later are
        not reflected in the returned ``Tool``.
        """
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )