"""调用GitHub的工具。"""
from __future__ import annotations

import itertools
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import requests
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from github.Issue import Issue
    from github.PullRequest import PullRequest
def _import_tiktoken() -> Any:
"""导入tiktoken。"""
try:
import tiktoken
except ImportError:
raise ImportError(
"tiktoken is not installed. "
"Please install it with `pip install tiktoken`"
)
return tiktoken
class GitHubAPIWrapper(BaseModel):
    """Wrapper for the GitHub API, authenticated as a GitHub App installation.

    Methods return plain text (or JSON-serialisable data) meant to be consumed
    by an agent. Many of them report failures as strings instead of raising.
    """

    github: Any  #: :meta private:
    github_repo_instance: Any  #: :meta private:
    # Repository name passed to `get_repo` (presumably "owner/repo").
    github_repository: Optional[str] = None
    # GitHub App id.
    github_app_id: Optional[str] = None
    # Either the private key itself or a path to a file containing it.
    github_app_private_key: Optional[str] = None
    # Branch the bot reads from and commits to; defaults to the repo default branch.
    active_branch: Optional[str] = None
    # Branch that pull requests target; defaults to the repo default branch.
    github_base_branch: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials and connect to the configured repository.

        Each setting is taken from ``values`` or, failing that, from its
        environment variable (GITHUB_REPOSITORY, GITHUB_APP_ID,
        GITHUB_APP_PRIVATE_KEY, GITHUB_BASE_BRANCH, ACTIVE_BRANCH). Both branch
        settings default to the repository's default branch.

        Raises:
            ImportError: If PyGithub is not installed.
            ValueError: If the GitHub App has no installation.
        """
        github_repository = get_from_dict_or_env(
            values, "github_repository", "GITHUB_REPOSITORY"
        )
        github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID")
        github_app_private_key = get_from_dict_or_env(
            values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY"
        )

        try:
            from github import Auth, GithubIntegration
        except ImportError as e:
            raise ImportError(
                "PyGithub is not installed. "
                "Please install it with `pip install PyGithub`"
            ) from e

        # First treat the setting as a key-file path. If it cannot be opened or
        # read, treat the setting as the key itself.
        try:
            with open(github_app_private_key, "r") as f:
                private_key = f.read()
        except (OSError, ValueError):
            private_key = github_app_private_key

        auth = Auth.AppAuth(github_app_id, private_key)
        gi = GithubIntegration(auth=auth)
        installations = gi.get_installations()
        # NOTE: get_installations() returns a PyGithub PaginatedList. It does not
        # appear to support len()/truthiness, so "no installations" is detected
        # by the IndexError raised when indexing an empty list.
        try:
            installation = installations[0]
        except IndexError:
            raise ValueError(
                "Please make sure to install the created github app with id "
                f"{github_app_id} on the repo: {github_repository}. "
                "More instructions can be found at "
                "https://docs.github.com/en/apps/using-"
                "github-apps/installing-your-own-github-app"
            ) from None
        except ValueError as e:
            raise ValueError(
                "Please make sure to give correct github parameters. "
                f"Error message: {e}"
            ) from e

        # Create a GitHub client scoped to this installation.
        g = installation.get_github_for_installation()
        repo = g.get_repo(github_repository)

        github_base_branch = get_from_dict_or_env(
            values,
            "github_base_branch",
            "GITHUB_BASE_BRANCH",
            default=repo.default_branch,
        )
        active_branch = get_from_dict_or_env(
            values,
            "active_branch",
            "ACTIVE_BRANCH",
            default=repo.default_branch,
        )

        values["github"] = g
        values["github_repo_instance"] = repo
        values["github_repository"] = github_repository
        values["github_app_id"] = github_app_id
        values["github_app_private_key"] = github_app_private_key
        values["active_branch"] = active_branch
        values["github_base_branch"] = github_base_branch
        return values

    @staticmethod
    def _split_first_line(query: str) -> Tuple[str, str]:
        """Split ``query`` into its first line and the remaining text.

        Agent inputs put a header (a title, path or issue number) on the first
        line, usually followed by a blank line and then the body. Exactly one
        separating blank line is removed if present. A single newline is also
        accepted without losing the first character of the body.
        """
        first_line, _, rest = query.partition("\n")
        if rest.startswith("\n"):
            rest = rest[1:]
        return first_line, rest

    def _protected_branch_message(self) -> str:
        """Build the message returned when a write targets the base branch."""
        return (
            "You're attempting to commit directly to the "
            f"{self.github_base_branch} branch, which is protected. "
            "Please create a new branch and try again."
        )

    def parse_issues(self, issues: List[Issue]) -> List[dict]:
        """Extract the title, number and opener of each issue.

        Args:
            issues: GitHub issue objects.

        Returns:
            One dict per issue with ``title`` and ``number`` keys. An
            ``opened_by`` key is added when the issue's user login is known.
        """
        parsed = []
        for issue in issues:
            issue_dict = {"title": issue.title, "number": issue.number}
            opened_by = issue.user.login if issue.user else None
            if opened_by is not None:
                issue_dict["opened_by"] = opened_by
            parsed.append(issue_dict)
        return parsed

    def parse_pull_requests(self, pull_requests: List[PullRequest]) -> List[dict]:
        """Extract the title, number, commit count and comment count of each PR.

        Args:
            pull_requests: GitHub pull request objects.

        Returns:
            One dict per PR with ``title``, ``number``, ``commits`` and
            ``comments`` keys. The two counts are stringified.
        """
        return [
            {
                "title": pr.title,
                "number": pr.number,
                "commits": str(pr.commits),
                "comments": str(pr.comments),
            }
            for pr in pull_requests
        ]

    def get_issues(self) -> str:
        """List all open issues in the repository, excluding pull requests.

        Returns:
            A plain-text report with the number of issues and each issue's
            title and number.
        """
        issues = self.github_repo_instance.get_issues(state="open")
        # The issues endpoint also returns pull requests; drop them.
        issues = [issue for issue in issues if not issue.pull_request]
        if not issues:
            return "No open issues available"
        parsed_issues = self.parse_issues(issues)
        return f"Found {len(parsed_issues)} issues:\n{parsed_issues}"

    def list_open_pull_requests(self) -> str:
        """List all open pull requests in the repository.

        Returns:
            A plain-text report with the number of PRs and each PR's title and
            number.
        """
        pull_requests = self.github_repo_instance.get_pulls(state="open")
        if pull_requests.totalCount == 0:
            return "No open pull requests available"
        parsed_prs = self.parse_pull_requests(pull_requests)
        return f"Found {len(parsed_prs)} pull requests:\n{parsed_prs}"

    def _list_files_recursive(
        self, directory_path: str, ref: Optional[str]
    ) -> List[str]:
        """Return every file path under ``directory_path`` at ``ref``, recursing
        into sub-directories. API errors propagate to the caller."""
        contents = self.github_repo_instance.get_contents(directory_path, ref=ref)
        # get_contents may return a single item rather than a list (e.g. for a
        # file path); normalise so iteration works either way.
        if not isinstance(contents, list):
            contents = [contents]
        files: List[str] = []
        for content in contents:
            if content.type == "dir":
                files.extend(self._list_files_recursive(content.path, ref))
            else:
                files.append(content.path)
        return files

    def list_files_in_main_branch(self) -> str:
        """List every file in the repository's base branch.

        Sub-directories are read from the same branch as the top level.

        Returns:
            A plain-text report of file paths, or the error text on failure.
        """
        try:
            files = self._list_files_recursive("", self.github_base_branch)
        except Exception as e:
            return str(e)
        if not files:
            return "No files found in the main branch"
        files_str = "\n".join(files)
        return f"Found {len(files)} files in the main branch:\n{files_str}"

    def set_active_branch(self, branch_name: str) -> str:
        """Make ``branch_name`` the bot's working branch (like ``git checkout``).

        Args:
            branch_name: An existing branch in the repository.

        Returns:
            A success message, or an error string listing the existing branches
            if ``branch_name`` does not exist.
        """
        curr_branches = [
            branch.name for branch in self.github_repo_instance.get_branches()
        ]
        if branch_name not in curr_branches:
            return (
                f"Error {branch_name} does not exist, "
                f"in repo with current branches: {str(curr_branches)}"
            )
        self.active_branch = branch_name
        return f"Switched to branch `{branch_name}`"

    def list_branches_in_repo(self) -> str:
        """List all branches in the repository.

        Returns:
            A plain-text report of branch names, or the error text on failure.
        """
        try:
            branches = [
                branch.name for branch in self.github_repo_instance.get_branches()
            ]
        except Exception as e:
            return str(e)
        if not branches:
            return "No branches found in the repository"
        branches_str = "\n".join(branches)
        return f"Found {len(branches)} branches in the repository:\n{branches_str}"

    def create_branch(self, proposed_branch_name: str) -> str:
        """Create a branch from the repo's default branch and make it active.

        Roughly equivalent to ``git switch -c proposed_branch_name``. If the name
        is taken, ``_v1``, ``_v2``, ... are appended until an unused name is found
        (at most 1000 attempts).

        Args:
            proposed_branch_name: Desired branch name.

        Returns:
            A success message, or a failure message if all 1000 names are taken.

        Raises:
            Exception: If GitHub rejects the request for any reason other than
                the reference already existing.
        """
        from github import GithubException

        base_branch = self.github_repo_instance.get_branch(
            self.github_repo_instance.default_branch
        )
        for attempt in range(1000):
            new_branch_name = (
                proposed_branch_name
                if attempt == 0
                else f"{proposed_branch_name}_v{attempt}"
            )
            try:
                self.github_repo_instance.create_git_ref(
                    ref=f"refs/heads/{new_branch_name}", sha=base_branch.commit.sha
                )
            except GithubException as e:
                # e.data is not guaranteed to be a dict containing "message".
                message = (
                    str(e.data.get("message", "")) if isinstance(e.data, dict) else ""
                )
                if e.status == 422 and "Reference already exists" in message:
                    continue
                print(f"Failed to create branch. Error: {e}")  # noqa: T201
                raise Exception(
                    "Unable to create branch name from proposed_branch_name: "
                    f"{proposed_branch_name}"
                ) from e
            self.active_branch = new_branch_name
            return (
                f"Branch '{new_branch_name}' "
                "created successfully, and set as current active branch."
            )
        return (
            "Unable to create branch. "
            "At least 1000 branches exist with names derived from "
            f"proposed_branch_name: `{proposed_branch_name}`"
        )

    def list_files_in_bot_branch(self) -> str:
        """List every file in the active branch (the branch the bot edits).

        Returns:
            A plain-text list of file paths, or ``"Error: ..."`` on failure.
        """
        try:
            files = self._list_files_recursive("", self.active_branch)
        except Exception as e:
            return f"Error: {e}"
        if not files:
            return f"No files found in branch: `{self.active_branch}`"
        files_str = "\n".join(files)
        return f"Found {len(files)} files in branch `{self.active_branch}`:\n{files_str}"

    def get_files_from_directory(self, directory_path: str) -> str:
        """Recursively list the files under a directory of the active branch.

        Args:
            directory_path: Directory path within the repository.

        Returns:
            The ``str()`` of the list of file paths, or an error message if a
            GitHub API call fails.
        """
        from github import GithubException

        try:
            files = self._list_files_recursive(directory_path, self.active_branch)
        except GithubException as e:
            return f"Error: status code {e.status}, {e.message}"
        return str(files)

    def get_issue(self, issue_number: int) -> Dict[str, Any]:
        """Fetch an issue together with its first 10 comments.

        Args:
            issue_number: Number of the GitHub issue.

        Returns:
            A dict with ``number``, ``title``, ``body``, ``comments`` (the
            ``str()`` of a list of ``{"body", "user"}`` dicts) and ``opened_by``
            (the opener's login, or ``"None"`` if unknown).
        """
        issue = self.github_repo_instance.get_issue(number=issue_number)
        comments: List[dict] = [
            {"body": comment.body, "user": comment.user.login}
            for comment in itertools.islice(issue.get_comments(), 10)
        ]
        opened_by = None
        if issue.user and issue.user.login:
            opened_by = issue.user.login
        return {
            "number": issue_number,
            "title": issue.title,
            "body": issue.body,
            "comments": str(comments),
            "opened_by": str(opened_by),
        }

    def list_pull_request_files(self, pr_number: int) -> List[Dict[str, Any]]:
        """Fetch the full text of a PR's files, within a 3,000-token budget.

        Files whose text would push the running total past the budget are
        skipped. Later, smaller files may still be included. Files that fail to
        download are skipped and a message is printed.
        # TODO: summarize long files (e.g. via ctags) instead of skipping them.

        Args:
            pr_number: Number of the GitHub pull request.

        Returns:
            One dict per included file with ``filename``, ``contents``,
            ``additions`` and ``deletions`` keys.
        """
        tiktoken = _import_tiktoken()
        encoding = tiktoken.get_encoding("cl100k_base")
        max_tokens_for_files = 3_000
        # Seconds; stops a stalled download from hanging the tool forever.
        request_timeout = 30
        pr = self.github_repo_instance.get_pull(number=int(pr_number))
        pr_files: List[Dict[str, Any]] = []
        total_tokens = 0
        for file in pr.get_files():
            try:
                # NOTE(review): these requests carry no credentials, so they
                # presumably fail for private repositories; consider fetching
                # the contents through the authenticated client instead.
                metadata_response = requests.get(
                    file.contents_url, timeout=request_timeout
                )
                if metadata_response.status_code != 200:
                    print(f"Failed to download file: {file.contents_url}, skipping")  # noqa: T201
                    continue
                download_url = metadata_response.json()["download_url"]

                content_response = requests.get(download_url, timeout=request_timeout)
                if content_response.status_code != 200:
                    print(  # noqa: T201
                        "Failed downloading file content "
                        f"(Error {content_response.status_code}). Skipping"
                    )
                    continue
                file_content = content_response.text

                file_tokens = len(
                    encoding.encode(
                        file_content + file.filename + "file_name file_contents"
                    )
                )
                if total_tokens + file_tokens < max_tokens_for_files:
                    pr_files.append(
                        {
                            "filename": file.filename,
                            "contents": file_content,
                            "additions": file.additions,
                            "deletions": file.deletions,
                        }
                    )
                    total_tokens += file_tokens
            except Exception as e:
                print(f"Error when reading files from a PR on github. {e}")  # noqa: T201
        return pr_files

    def get_pull_request(self, pr_number: int, max_tokens: int = 2_000) -> Dict[str, Any]:
        """Fetch a pull request with up to 10 comments and 10 commits.

        Fields are added in this order, stopping when ``max_tokens``
        (cl100k_base tokens) is reached: title, number, body, comments, commits.
        A title, number or body that does not fit is omitted. Comments and
        commits are taken in order until the next one would not fit. Each token
        is counted once.

        Args:
            pr_number: Number of the GitHub pull request.
            max_tokens: Token budget for the whole response.

        Returns:
            A dict of strings. ``comments`` and ``commits`` are the ``str()`` of
            lists of stringified dicts.
        """
        tiktoken = _import_tiktoken()
        encoding = tiktoken.get_encoding("cl100k_base")
        pull = self.github_repo_instance.get_pull(number=pr_number)
        response_dict: Dict[str, str] = {}
        total_tokens = 0

        def count_tokens(text: str) -> int:
            return len(encoding.encode(text))

        def add_if_fits(key: str, value: str) -> None:
            # Store value under key only if it fits in the remaining budget.
            nonlocal total_tokens
            tokens = count_tokens(value)
            if total_tokens + tokens <= max_tokens:
                response_dict[key] = value
                total_tokens += tokens

        def take_within_budget(items: Any, render: Any) -> List[str]:
            # Render up to 10 items, stopping at the first that would overflow.
            nonlocal total_tokens
            taken: List[str] = []
            for item in itertools.islice(items, 10):
                item_str = render(item)
                tokens = count_tokens(item_str)
                if total_tokens + tokens > max_tokens:
                    break
                taken.append(item_str)
                total_tokens += tokens
            return taken

        add_if_fits("title", pull.title)
        add_if_fits("number", str(pr_number))
        # The body is None when the PR has no description.
        add_if_fits("body", pull.body or "")

        comments = take_within_budget(
            pull.get_issue_comments(),
            lambda c: str({"body": c.body, "user": c.user.login}),
        )
        # These tokens were already counted item by item; do not count them again.
        response_dict["comments"] = str(comments)

        commits = take_within_budget(
            pull.get_commits(),
            lambda c: str({"message": c.commit.message}),
        )
        response_dict["commits"] = str(commits)
        return response_dict

    def create_pull_request(self, pr_query: str) -> str:
        """Open a pull request from the active branch into the base branch.

        Args:
            pr_query: The PR title on the first line, normally followed by a
                blank line, then the PR body. For example::

                    README updates

                    added contributors' names

        Returns:
            A success or failure message.
        """
        if self.github_base_branch == self.active_branch:
            return (
                "Cannot make a pull request because "
                "commits are already in the main or master branch."
            )
        try:
            title, body = self._split_first_line(pr_query)
            pr = self.github_repo_instance.create_pull(
                title=title,
                body=body,
                head=self.active_branch,
                base=self.github_base_branch,
            )
            return f"Successfully created PR number {pr.number}"
        except Exception as e:
            return "Unable to make pull request due to error:\n" + str(e)

    def comment_on_issue(self, comment_query: str) -> str:
        """Add a comment to a GitHub issue.

        Args:
            comment_query: The issue number on the first line, normally followed
                by a blank line, then the comment text, e.g. "1", blank line,
                "Working on it now".

        Returns:
            A success or failure message.
        """
        issue_line, comment = self._split_first_line(comment_query)
        try:
            issue_number = int(issue_line.strip())
        except ValueError:
            return (
                "Unable to make comment due to error:\n"
                f"invalid issue number `{issue_line}`"
            )
        try:
            issue = self.github_repo_instance.get_issue(number=issue_number)
            issue.create_comment(comment)
            return f"Commented on issue {issue_number}"
        except Exception as e:
            return "Unable to make comment due to error:\n" + str(e)

    def create_file(self, file_query: str) -> str:
        """Create a new file on the active branch.

        Args:
            file_query: The file path on the first line, normally followed by a
                blank line, then the file contents. For example::

                    hello_world.md

                    # Hello World!

        Returns:
            A success or failure message. Writing to the base branch, or to a
            path that already exists, is refused.
        """
        if self.active_branch == self.github_base_branch:
            return self._protected_branch_message()
        file_path, file_contents = self._split_first_line(file_query)
        try:
            from github import UnknownObjectException

            try:
                existing = self.github_repo_instance.get_contents(
                    file_path, ref=self.active_branch
                )
            except UnknownObjectException:
                # Expected: the file should not exist yet (404).
                existing = None
            if existing:
                return (
                    f"File already exists at `{file_path}` "
                    f"on branch `{self.active_branch}`. You must use "
                    "`update_file` to modify it."
                )
            self.github_repo_instance.create_file(
                path=file_path,
                message="Create " + file_path,
                content=file_contents,
                branch=self.active_branch,
            )
            return "Created file " + file_path
        except Exception as e:
            return "Unable to make file due to error:\n" + str(e)

    def read_file(self, file_path: str) -> str:
        """Read a file from the active branch.

        Args:
            file_path: Path of the file within the repository.

        Returns:
            The file decoded as UTF-8, or an error message if it cannot be read.
        """
        try:
            file = self.github_repo_instance.get_contents(
                file_path, ref=self.active_branch
            )
            return file.decoded_content.decode("utf-8")
        except Exception as e:
            return (
                f"File not found `{file_path}` on branch "
                f"`{self.active_branch}`. Error: {str(e)}"
            )

    def update_file(self, file_query: str) -> str:
        """Replace a snippet of a file on the active branch.

        Args:
            file_query: The file path on the first line. The old text is wrapped
                in ``OLD <<<<`` / ``>>>> OLD`` and the new text in ``NEW <<<<`` /
                ``>>>> NEW``. For example::

                    test/hello.txt
                    OLD <<<<
                    Hello Earth!
                    >>>> OLD
                    NEW <<<<
                    Hello Mars!
                    >>>> NEW

                Both snippets are stripped of surrounding whitespace. Every
                occurrence of the old text is replaced.

        Returns:
            A success or failure message.
        """
        if self.active_branch == self.github_base_branch:
            return self._protected_branch_message()
        try:
            file_path: str = file_query.split("\n")[0]
            old_file_contents = (
                file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
            )
            new_file_contents = (
                file_query.split("NEW <<<<")[1].split(">>>> NEW")[0].strip()
            )
            # str.replace("", x) would insert x between every character.
            if not old_file_contents:
                return (
                    "File content was not updated because the OLD <<<< >>>> OLD "
                    "block is empty."
                )
            # Fetch once, so the edited text and the sha come from the same revision.
            file = self.github_repo_instance.get_contents(
                file_path, ref=self.active_branch
            )
            file_content = file.decoded_content.decode("utf-8")
            updated_file_content = file_content.replace(
                old_file_contents, new_file_contents
            )
            if file_content == updated_file_content:
                return (
                    "File content was not updated because old content was not found. "
                    "It may be helpful to use the read_file action to get "
                    "the current file contents."
                )
            self.github_repo_instance.update_file(
                path=file_path,
                message="Update " + str(file_path),
                content=updated_file_content,
                branch=self.active_branch,
                sha=file.sha,
            )
            return "Updated file " + str(file_path)
        except Exception as e:
            return "Unable to update file due to error:\n" + str(e)

    def delete_file(self, file_path: str) -> str:
        """Delete a file from the active branch.

        Args:
            file_path: Path of the file to delete.

        Returns:
            A success or failure message.
        """
        if self.active_branch == self.github_base_branch:
            return self._protected_branch_message()
        try:
            self.github_repo_instance.delete_file(
                path=file_path,
                message="Delete " + file_path,
                branch=self.active_branch,
                sha=self.github_repo_instance.get_contents(
                    file_path, ref=self.active_branch
                ).sha,
            )
            return "Deleted file " + file_path
        except Exception as e:
            return "Unable to delete file due to error:\n" + str(e)

    def search_issues_and_prs(self, query: str) -> str:
        """Search the repository's issues and pull requests.

        Args:
            query: Search query.

        Returns:
            The title, number and state of up to 5 results, one per line.
        """
        search_result = self.github.search_issues(query, repo=self.github_repository)
        max_items = min(5, search_result.totalCount)
        results = [f"Top {max_items} results:"]
        for issue in search_result[:max_items]:
            results.append(
                f"Title: {issue.title}, Number: {issue.number}, State: {issue.state}"
            )
        return "\n".join(results)

    def search_code(self, query: str) -> str:
        """Search the repository's code.

        For each of the first 5 matches, the file contents are read from the
        active branch.
        # TODO: cap the total number of tokens returned.

        Args:
            query: Search query.

        Returns:
            Up to 5 results, each with its path and full file contents.
        """
        search_result = self.github.search_code(
            query=query, repo=self.github_repository
        )
        total = search_result.totalCount
        if total == 0:
            return "0 results found."
        max_results = min(5, total)
        results = [f"Showing top {max_results} of {total} results:"]
        for code in search_result[:max_results]:
            file_content = self.github_repo_instance.get_contents(
                code.path, ref=self.active_branch
            ).decoded_content.decode()
            results.append(
                f"Filepath: `{code.path}`\nFile contents: "
                f"{file_content}\n<END OF FILE>"
            )
        return "\n".join(results)

    def create_review_request(self, reviewer_username: str) -> str:
        """Request a review on the open PR whose head is the active branch.

        Args:
            reviewer_username: Username of the requested reviewer.

        Returns:
            A confirmation, or a message explaining why no request was made.
        """
        pull_requests = self.github_repo_instance.get_pulls(
            state="open", sort="created"
        )
        pr = next(
            (pr for pr in pull_requests if pr.head.ref == self.active_branch), None
        )
        if pr is None:
            return (
                "No open pull request found for the "
                f"current branch `{self.active_branch}`"
            )
        try:
            pr.create_review_request(reviewers=[reviewer_username])
            return (
                f"Review request created for user {reviewer_username} "
                f"on PR #{pr.number}"
            )
        except Exception as e:
            return f"Failed to create a review request with error {e}"

    def run(self, mode: str, query: str) -> str:
        """Run the operation named by ``mode`` with ``query`` as its input.

        Modes that take no input ignore ``query``. ``get_issue``,
        ``get_pull_request`` and ``list_pull_request_files`` convert ``query``
        with ``int()`` and return JSON.

        Raises:
            ValueError: If ``mode`` is unknown, or ``query`` is not an integer
                for a mode that needs one.
        """
        # Lambdas defer int(query) so it only runs for modes that need it.
        handlers: Dict[str, Any] = {
            "get_issue": lambda: json.dumps(self.get_issue(int(query))),
            "get_pull_request": lambda: json.dumps(
                self.get_pull_request(int(query))
            ),
            "list_pull_request_files": lambda: json.dumps(
                self.list_pull_request_files(int(query))
            ),
            "get_issues": self.get_issues,
            "comment_on_issue": lambda: self.comment_on_issue(query),
            "create_file": lambda: self.create_file(query),
            "create_pull_request": lambda: self.create_pull_request(query),
            "read_file": lambda: self.read_file(query),
            "update_file": lambda: self.update_file(query),
            "delete_file": lambda: self.delete_file(query),
            "list_open_pull_requests": self.list_open_pull_requests,
            "list_files_in_main_branch": self.list_files_in_main_branch,
            "list_files_in_bot_branch": self.list_files_in_bot_branch,
            "list_branches_in_repo": self.list_branches_in_repo,
            "set_active_branch": lambda: self.set_active_branch(query),
            "create_branch": lambda: self.create_branch(query),
            "get_files_from_directory": lambda: self.get_files_from_directory(query),
            "search_issues_and_prs": lambda: self.search_issues_and_prs(query),
            "search_code": lambda: self.search_code(query),
            "create_review_request": lambda: self.create_review_request(query),
        }
        handler = handlers.get(mode)
        if handler is None:
            raise ValueError(f"Invalid mode: {mode}")
        return handler()