# Source code for langchain_community.tools.e2b_data_analysis.tool
from __future__ import annotations
import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type, Union
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManager,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
from langchain_community.tools import BaseTool, Tool
from langchain_community.tools.e2b_data_analysis.unparse import Unparser
if TYPE_CHECKING:
from e2b import EnvVars
from e2b.templates.data_analysis import Artifact
# Default description of the sandbox tool; passed as the tool's initial
# ``description`` when the tool is constructed. The value is a single line:
# the fragments below are joined with no newline between them.
base_description = (
    "Evaluates python code in a sandbox environment. "
    "The environment is long running and exists across multiple executions. "
    "You must send the whole script every time and print your outputs. "
    "Script should be pure python code that can be evaluated. "
    "It should be in python format NOT markdown. "
    "The code should NOT be wrapped in backticks. "
    "All python packages including requests, matplotlib, scipy, numpy, pandas, "
    "etc are available. Create and display chart using `plt.show()`."
)
def _unparse(tree: ast.AST) -> str:
"""取消解析AST。"""
if version_info.minor < 9:
s = StringIO()
Unparser(tree, file=s)
source_code = s.getvalue()
s.close()
else:
source_code = ast.unparse(tree) # type: ignore[attr-defined]
return source_code
def add_last_line_print(code: str) -> str:
    """Add a ``print`` call around a trailing bare expression, if one is missing.

    LLM-generated code sometimes ends with just ``variable_name`` (as one would
    type in a REPL) instead of ``print(variable_name)``. This function parses
    the code, inspects the last top-level statement and, when it is a bare
    expression that is not already a call to ``print``, wraps it in
    ``print(...)``.

    Args:
        code: Python source code to inspect.

    Returns:
        The source regenerated from the (possibly modified) AST. Because the
        code is round-tripped through the AST, comments and original
        formatting are not preserved.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """
    tree = ast.parse(code)

    # An empty or comment-only script has no statements; indexing ``body[-1]``
    # would raise IndexError, so return the (empty) unparsed module instead.
    if not tree.body:
        return _unparse(tree)

    last_stmt = tree.body[-1]
    # Only bare expressions are candidates; assignments, loops, defs, etc.
    # are left untouched.
    if not isinstance(last_stmt, ast.Expr):
        return _unparse(tree)

    value = last_stmt.value
    is_print_call = (
        isinstance(value, ast.Call)
        and isinstance(value.func, ast.Name)
        and value.func.id == "print"
    )
    if not is_print_call:
        tree.body[-1] = ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[value],
                keywords=[],
            )
        )
    return _unparse(tree)
class UploadedFile(BaseModel):
    """Description of an uploaded file together with its remote path."""

    # File name of the uploaded file.
    name: str
    # Path of the file inside the sandbox.
    remote_path: str
    # Free-text description of the file; may be an empty string.
    description: str
class E2BDataAnalysisToolArguments(BaseModel):
    """Argument schema for ``E2BDataAnalysisTool``."""

    # The single required argument: the full script to evaluate.
    python_code: str = Field(
        ...,
        example="print('Hello World')",
        description=(
            "The python script to be evaluated. "
            "The contents will be in main.py. "
            "It should not be in markdown format."
        ),
    )
class E2BDataAnalysisTool(BaseTool):
    """Tool for running Python code in a sandboxed environment for data analysis."""

    name = "e2b_data_analysis"
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    session: Any
    description: str
    # Files uploaded through ``upload_file``; used to build the description.
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ):
        """Create the tool and open an E2B ``DataAnalysis`` session.

        Args:
            api_key: E2B API key, forwarded to the session.
            cwd: Forwarded to the session as ``cwd``.
            env_vars: Forwarded to the session as ``env_vars``.
            on_stdout: Callback forwarded to the session as ``on_stdout``.
            on_stderr: Callback forwarded to the session as ``on_stderr``.
            on_artifact: Callback forwarded to the session as ``on_artifact``.
            on_exit: Callback forwarded to the session as ``on_exit``.
            **kwargs: Extra keyword arguments passed to ``BaseTool``.

        Raises:
            ImportError: If the ``e2b`` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        # If no API key is provided, E2B will try to read it from the environment
        # variable E2B_API_KEY
        super().__init__(description=base_description, **kwargs)
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    @property
    def uploaded_files_description(self) -> str:
        """Return a bullet list describing the uploaded files.

        Returns:
            An empty string when no files are tracked, otherwise a header line
            followed by one entry per uploaded file.
        """
        if not self._uploaded_files:
            return ""
        lines = ["The following files available in the sandbox:"]
        for f in self._uploaded_files:
            if f.description == "":
                lines.append(f"- path: `{f.remote_path}`")
            else:
                lines.append(
                    f"- path: `{f.remote_path}` \n description: `{f.description}`"
                )
        return "\n".join(lines)

    def _refresh_description(self) -> None:
        """Rebuild ``description`` from the base text and the current file list.

        Rebuilding (rather than appending to the existing description) keeps
        the file list from being duplicated on every upload and ensures that
        removed files no longer appear.
        """
        files_description = self.uploaded_files_description
        if files_description:
            self.description = base_description + "\n" + files_description
        else:
            self.description = base_description

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Run the given Python code in the sandbox session.

        Args:
            python_code: The script to execute. A trailing bare expression is
                wrapped in ``print(...)`` before execution.
            run_manager: Unused by this implementation.
            callbacks: Optional callback manager whose metadata may carry an
                ``on_artifact`` handler to pass to the session.

        Returns:
            A JSON string with the keys ``stdout``, ``stderr`` and
            ``artifacts`` (the names of the produced artifacts).
        """
        python_code = add_last_line_print(python_code)

        on_artifact = None
        if callbacks is not None:
            metadata = callbacks.metadata
            if isinstance(metadata, dict):
                # ``getattr`` on a dict only finds dict attributes, never its
                # keys, so the handler has to be looked up by key.
                on_artifact = metadata.get("on_artifact")
            else:
                on_artifact = getattr(metadata, "on_artifact", None)

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )
        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": [artifact.name for artifact in artifacts],
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run a shell command in the sandbox.

        Args:
            cmd: The command to start in the sandbox.

        Returns:
            A dict with the keys ``stdout``, ``stderr`` and ``exit_code`` taken
            from the finished process output.
        """
        proc = self.session.process.start(cmd)
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install Python packages in the sandbox.

        Args:
            package_names: A single package name or a list of names.
        """
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: Union[str, List[str]]) -> None:
        """Install system packages (via apt) in the sandbox.

        Args:
            package_names: A single package name or a list of names.
        """
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download a file from the sandbox.

        Args:
            remote_path: Path of the file inside the sandbox.

        Returns:
            Whatever the session's ``download_file`` returns for that path.
        """
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload a file to the sandbox.

        The file is uploaded to the '/home/user/<filename>' path.

        Args:
            file: Open file object to upload; its ``name`` attribute is used
                to derive the recorded file name.
            description: Description of the file, included in the tool
                description; may be an empty string.

        Returns:
            The ``UploadedFile`` record that was added to the tracked files.
        """
        remote_path = self.session.upload_file(file)
        f = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        self._uploaded_files.append(f)
        self._refresh_description()
        return f

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove an uploaded file from the sandbox.

        Args:
            uploaded_file: Record of the file to delete; every tracked file
                with the same ``remote_path`` is dropped from the list.
        """
        self.session.filesystem.remove(uploaded_file.remote_path)
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        self._refresh_description()

    def as_tool(self) -> Tool:
        """Wrap this tool's ``_run`` method in a generic ``Tool``.

        The returned tool copies the current ``name``, ``description`` and
        ``args_schema``; later changes to this object's description are not
        reflected in a tool that was already created.
        """
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )