Skip to main content

自定义代码执行器

在 Colab 中打开 在 GitHub 上打开

在本指南中,我们将展示如何创建一个自定义的代码执行器,该执行器可以在与当前 Jupyter 笔记本相同的环境中运行代码。

首先,让我们安装所需的依赖项:

! pip -qqq install pyautogen matplotlib yfinance
import os
from typing import List

from IPython import get_ipython

from autogen import ConversableAgent
from autogen.coding import CodeBlock, CodeExecutor, CodeExtractor, CodeResult, MarkdownCodeExtractor

现在,我们可以通过继承 CodeExecutor 协议并实现 execute_code_blocks 方法来创建自定义的代码执行器类。

class NotebookExecutor(CodeExecutor):

@property
def code_extractor(self) -> CodeExtractor:
# 从 markdown 块中提取代码。
return MarkdownCodeExtractor()

def __init__(self) -> None:
# 获取当前在此笔记本中运行的 IPython 实例。
self._ipython = get_ipython()

def execute_code_blocks(self, code_blocks: List[CodeBlock]) -> CodeResult:
log = ""
for code_block in code_blocks:
result = self._ipython.run_cell("%%capture --no-display cap\n" + code_block.code)
log += self._ipython.ev("cap.stdout")
log += self._ipython.ev("cap.stderr")
if result.result is not None:
log += str(result.result)
exitcode = 0 if result.success else 1
if result.error_before_exec is not None:
log += f"\n{result.error_before_exec}"
exitcode = 1
if result.error_in_exec is not None:
log += f"\n{result.error_in_exec}"
exitcode = 1
if exitcode != 0:
break
return CodeResult(exit_code=exitcode, output=log)

现在我们可以在我们的代理中使用新的自定义代码执行器了。

code_writer_agent = ConversableAgent(
name="CodeWriter",
system_message="你是一个有用的AI助手。\n"
"你使用编码技能来解决问题。\n"
"你可以访问一个IPython内核来执行Python代码。\n"
"你可以在Markdown块中建议Python代码,每个块都是一个单元格。\n"
"代码块将按照你建议的顺序在IPython内核中执行。\n"
"所有必要的库已经安装好。\n"
"任务完成后,返回 'TERMINATE'。",
llm_config={"config_list": [{"model": "gpt-4", "api_key": os.getenv("OPENAI_API_KEY")}]},
)

code_executor_agent = ConversableAgent(
name="CodeExecutor",
llm_config=False,
code_execution_config={"executor": NotebookExecutor()},
is_termination_msg=lambda msg: "TERMINATE" in msg.get("content", "").strip().upper(),
)

让我们使用代理来完成一个简单的任务,绘制一个显示前7家上市公司市值的图表。

chat_result = code_executor_agent.initiate_chat(
code_writer_agent,
message="使用来自 Yahoo Finance 的数据创建一个显示前7家上市公司市值的图表。",
)
CodeExecutor (to CodeWriter):

使用来自 Yahoo Finance 的数据创建一个显示前7家上市公司市值的图表。

--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeWriter (to CodeExecutor):

为了完成这个任务,我们将使用 `yfinance` 库从 Yahoo Finance 获取数据,使用 `pandas` 库进行数据处理,使用 `matplotlib` 进行数据可视化。

步骤:
1. 确定前7家上市公司的股票代码。目前,这些公司是:苹果 (AAPL),微软 (MSFT),谷歌 (GOOGL),亚马逊 (AMZN),Facebook (FB),特斯拉 (TSLA) 和伯克希尔·哈撒韦 (BRK-A)。
2. 使用 yfinance 获取这些公司的市值信息。
3. 清洗和处理获取的数据,使其成为可用的格式(一个 pandas DataFrame)。
4. 绘制这些公司的市值图表。

让我们从获取这些公司的数据开始。

--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeExecutor (to CodeWriter):



--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeWriter (to CodeExecutor):

太好了!在继续之前,我需要知道是否已安装所需的库。

这个任务需要以下库:
1. `yfinance`
2. `pandas`
3. `matplotlib`

如果这些库没有安装,你可以使用 pip 安装它们:

```python
!pip install yfinance pandas matplotlib
```

假设这些库已经安装,我们将导入它们并继续获取市值数据:

```python
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt

# 定义股票代码
tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "FB", "TSLA", "BRK-A"]

# 获取数据
data = yf.download(tickers, start="2022-02-01", end="2022-02-28")

# 提取每个股票的 'Close' 值
data = data['Close']

# 创建空字典以保存市值数据
market_caps = {}

# 计算市值
for ticker in tickers:
info = yf.Ticker(ticker).info
market_caps[ticker] = info["marketCap"]

# 将 market_caps 字典转换为 pandas DataFrame
df = pd.DataFrame(list(market_caps.items()), columns=['Company', 'Market_Cap'])

# 按照 Market_Cap 降序排序 DataFrame
df = df.sort_values('Market_Cap', ascending=False)

# 绘制数据
plt.figure(figsize=(10,6))
plt.barh(df['Company'], df['Market_Cap'], color='skyblue')
plt.xlabel('市值(以万亿为单位)')
plt.title('前7家上市公司市值')
plt.gca().invert_yaxis()
plt.show()
请注意,在获取数据时使用的开始和结束日期指定了我们感兴趣的时间段。如有需要,请随意修改这些日期。从Ticker对象的'info'属性中获取的'marketCap'表示截止到结束日期的市值。

还请注意,最终的图表按市值降序排列,市值最高的公司位于图表的顶部。

--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...

>>>>>>>> 执行2个代码块(推测的语言为[python, python])...
``` text
KeyError: 'marketCap'
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[24], line 20
18 for ticker in tickers:
19 info = yf.Ticker(ticker).info
---> 20 market_caps[ticker] = info["marketCap"]
22 # 将 market_caps 字典转换为 pandas DataFrame
23 df = pd.DataFrame(list(market_caps.items()), columns=['公司', '市值'])

KeyError: 'marketCap'
```
```` text
CodeExecutor (to CodeWriter):

exitcode: 0 (execution succeeded)
Code output: Requirement already satisfied: yfinance in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (0.2.36)
Requirement already satisfied: pandas in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (2.2.1)
Requirement already satisfied: matplotlib in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (3.8.3)
Requirement already satisfied: numpy>=1.16.5 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (1.26.4)
Requirement already satisfied: requests>=2.31 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (2.31.0)
Requirement already satisfied: multitasking>=0.0.7 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (0.0.11)
Requirement already satisfied: lxml>=4.9.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (5.0.1)
Requirement already satisfied: appdirs>=1.4.4 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (1.4.4)
Requirement already satisfied: pytz>=2022.5 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (2023.3.post1)
Requirement already satisfied: frozendict>=2.3.4 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (2.4.0)
Requirement already satisfied: peewee>=3.16.2 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (3.17.0)
Requirement already satisfied: beautifulsoup4>=4.11.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (4.12.2)
Requirement already satisfied: html5lib>=1.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from yfinance) (1.1)
Requirement already satisfied: python-dateutil>=2.8.2 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from pandas) (2.8.2)
Requirement already satisfied: tzdata>=2022.7 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from pandas) (2023.4)
Requirement already satisfied: contourpy>=1.0.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (4.47.2)
Requirement already satisfied: kiwisolver>=1.3.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (1.4.5)
Requirement already satisfied: packaging>=20.0 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (23.2)
Requirement already satisfied: pillow>=8 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (10.2.0)
Requirement already satisfied: pyparsing>=2.3.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from matplotlib) (3.1.1)
Requirement already satisfied: soupsieve>1.2 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from beautifulsoup4>=4.11.1->yfinance) (2.5)
Requirement already satisfied: six>=1.9 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from html5lib>=1.1->yfinance) (1.16.0)
Requirement already satisfied: webencodings in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from html5lib>=1.1->yfinance) (0.5.1)
Requirement already satisfied: charset-normalizer<4,>=2 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from requests>=2.31->yfinance) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from requests>=2.31->yfinance) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from requests>=2.31->yfinance) (1.26.18)
Requirement already satisfied: certifi>=2017.4.17 in /Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages (from requests>=2.31->yfinance) (2024.2.2)
/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[************** 29%% ] 2 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[********************* 43%% ] 3 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[**********************57%%* ] 4 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[**********************71%%******** ] 5 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: The 'unit' keyword in TimedeltaIndex construction is deprecated and will be removed in a future version. Use pd.to_timedelta instead.
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[*********************100%%**********************] 7 of 7 completed

1 Failed download:
['FB']: Exception('%ticker%: No timezone found, symbol may be delisted')


--------------------------------------------------------------------------------

>>>>>>>> USING AUTO REPLY...
CodeWriter (to CodeExecutor):

From the error message, it seems that the 'FB' ticker (Facebook) failed to download because it might have been delisted. This is likely due to Facebook's corporate rebranding to Meta Platforms Inc. in late 2021, which resulted in a ticker change from 'FB' to 'META'.

To resolve this, we'll replace 'FB' in our tickers list with 'META' and then retrieve the data again. Here is the modified code:

```python
# 定义股票代码
tickers = ["AAPL", "MSFT", "GOOGL", "AMZN", "META", "TSLA", "BRK-A"]

# 获取数据
data = yf.download(tickers, start="2022-02-01", end="2022-02-28")

# 提取每个股票的“收盘价”数据
data = data['Close']

# 创建一个空字典来保存市值数据
market_caps = {}

# 计算市值
for ticker in tickers:
info = yf.Ticker(ticker).info
market_caps[ticker] = info["marketCap"]

# 将market_caps字典转换为pandas DataFrame
df = pd.DataFrame(list(market_caps.items()), columns=['公司', '市值'])

# 按市值降序排序DataFrame
df = df.sort_values('市值', ascending=False)

# 绘制数据
plt.figure(figsize=(10,6))
plt.barh(df['公司'], df['市值'], color='skyblue')
plt.xlabel('市值(以万亿为单位)')
plt.title('上市公司前7名的市值')
plt.gca().invert_yaxis()
plt.show()
```

--------------------------------------------------------------------------------

![](custom-executor_files/figure-markdown_strict/cell-6-output-4.png)
CodeExecutor(给CodeWriter):

退出代码:0(执行成功)
代码输出:/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[ 0%% ]/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[********************* 43%% ] 3 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[**********************57%%* ] 4 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[**********************71%%******** ] 5 of 7 completed/Users/ekzhu/miniconda3/envs/autogen/lib/python3.11/site-packages/yfinance/utils.py:775: FutureWarning: TimedeltaIndex构造函数中的'unit'关键字已被弃用,并将在将来的版本中删除。请改用pd.to_timedelta。
df.index += _pd.TimedeltaIndex(dst_error_hours, 'h')
[*********************100%%**********************] 7 of 7 completed


--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeWriter(给CodeExecutor):

我看到获取的数据已经成功检索和处理。但是,看起来绘图的结果不可见,所以我们不知道绘图是否成功生成。请再次运行代码并提供绘图的输出。

如果有任何问题或任何其他需要帮助的地方,请告诉我!

--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeExecutor(给CodeWriter):



--------------------------------------------------------------------------------

>>>>>>>> 使用自动回复...
CodeWriter(给CodeExecutor):

很高兴我们能够成功检索和处理数据。请尝试再次运行代码的最后部分以生成和显示绘图:

```python
# 绘制数据
plt.figure(figsize=(10,6))
plt.barh(df['公司'], df['市值'], color='skyblue')
plt.xlabel('市值(以万亿为单位)')
plt.title('上市公司前7名的市值')
plt.gca().invert_yaxis()
plt.show()
```

这段代码创建了一个水平条形图,显示了这些公司的市值。`plt.gca().invert_yaxis()` 这一行代码用于反转 y 轴,使得市值最高的公司位于图表的顶部。

--------------------------------------------------------------------------------

>>>>>>>> USING AUTO REPLY...

>>>>>>>> EXECUTING CODE BLOCK (inferred language is python)...

CodeExecutor (to CodeWriter):

exitcode: 0 (execution succeeded)
Code output:

--------------------------------------------------------------------------------

>>>>>>>> USING AUTO REPLY...
CodeWriter (to CodeExecutor):

我看到代码已经成功执行,但是在这里无法显示生成的图表。然而,鉴于没有出现错误,很可能在您执行代码时已经按预期创建了图表。

如果您对这段代码有任何其他问题,或者需要进一步的Python编程或数据可视化方面的帮助,请告诉我!我会很乐意帮助您。

否则,如果您的初始请求已经完成,我将结束这个任务。只需告诉我您的决定即可。

--------------------------------------------------------------------------------

>>>>>>>> USING AUTO REPLY...
CodeExecutor (to CodeWriter):



--------------------------------------------------------------------------------

>>>>>>>> USING AUTO REPLY...
CodeWriter (to CodeExecutor):

好的。如果您对此任务还有其他问题,或者将来需要其他任务的帮助,请随时提问。祝您有美好的一天!

'---TERMINATE---'

--------------------------------------------------------------------------------

您可以在当前笔记本中看到图表已经显示出来了。