使用嵌入进行代码搜索
本笔记展示了如何使用Ada嵌入来实现语义代码搜索。在这个演示中,我们使用我们自己的openai-python代码库。我们实现了一个简单版本的文件解析和从python文件中提取函数的功能,这些函数可以被嵌入、索引和查询。
辅助函数
我们首先设置一些简单的解析函数,这些函数允许我们从我们的代码库中提取重要信息。
import pandas as pd
from pathlib import Path
DEF_PREFIXES = ['def ', 'async def ']
NEWLINE = '\n'
def get_function_name(code):
"""
从以'def'或'async def'开头的行中提取函数名。
"""
for prefix in DEF_PREFIXES:
if code.startswith(prefix):
return code[len(prefix): code.index('(')]
def get_until_no_space(all_lines, i):
"""
获取所有行,直到找到函数定义之外的行。
"""
ret = [all_lines[i]]
for j in range(i + 1, len(all_lines)):
if len(all_lines[j]) == 0 or all_lines[j][0] in [' ', '\t', ')']:
ret.append(all_lines[j])
else:
break
return NEWLINE.join(ret)
def get_functions(filepath):
"""
获取Python文件中的所有函数。
"""
with open(filepath, 'r') as file:
all_lines = file.read().replace('\r', NEWLINE).split(NEWLINE)
for i, l in enumerate(all_lines):
for prefix in DEF_PREFIXES:
if l.startswith(prefix):
code = get_until_no_space(all_lines, i)
function_name = get_function_name(code)
yield {
'code': code,
'function_name': function_name,
'filepath': filepath,
}
break
def extract_functions_from_repo(code_root):
"""
从仓库中提取所有.py文件中的函数。
"""
code_files = list(code_root.glob('**/*.py'))
num_files = len(code_files)
print(f'Total number of .py files: {num_files}')
if num_files == 0:
print('Verify openai-python repo exists and code_root is set correctly.')
return None
all_funcs = [
func
for code_file in code_files
for func in get_functions(str(code_file))
]
num_funcs = len(all_funcs)
print(f'Total number of functions extracted: {num_funcs}')
return all_funcs
数据加载
我们首先加载openai-python文件夹,并使用我们上面定义的函数提取所需的信息。
# Set user root directory to the 'openai-python' repository
root_dir = Path.home()
# Assumes the 'openai-python' repository exists in the user's root directory
code_root = root_dir / 'openai-python'
# 从仓库中提取所有功能
all_funcs = extract_functions_from_repo(code_root)
Total number of .py files: 51
Total number of functions extracted: 97
现在我们已经有了我们的内容,我们可以将数据传递给text-embedding-3-small
模型,并获得我们的向量嵌入。
from utils.embeddings_utils import get_embedding
df = pd.DataFrame(all_funcs)
df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-3-small'))
df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root))
df.to_csv("data/code_search_openai-python.csv", index=False)
df.head()
code | function_name | filepath | code_embedding | |
---|---|---|---|---|
0 | def _console_log_level(): if openai.log i... | _console_log_level | openai/util.py | [0.005937571171671152, 0.05450401455163956, 0.... |
1 | def log_debug(message, **params): msg = l... | log_debug | openai/util.py | [0.017557814717292786, 0.05647840350866318, -0... |
2 | def log_info(message, **params): msg = lo... | log_info | openai/util.py | [0.022524144500494003, 0.06219055876135826, -0... |
3 | def log_warn(message, **params): msg = lo... | log_warn | openai/util.py | [0.030524108558893204, 0.0667714849114418, -0.... |
4 | def logfmt(props): def fmt(key, val): ... | logfmt | openai/util.py | [0.05337328091263771, 0.03697286546230316, -0.... |
测试
让我们使用一些简单的查询来测试我们的端点。如果您熟悉openai-python
存储库,您会发现我们可以很容易地通过简单的英文描述找到我们要查找的函数。
我们定义了一个search_functions
方法,该方法接受包含嵌入、查询字符串和一些其他配置选项的数据。搜索我们的数据库的过程如下:
- 首先,我们使用
text-embedding-3-small
对我们的查询字符串(code_query
)进行嵌入。这里的理由是,像’a function that reverses a string’这样的查询字符串和像’def reverse(string): return string[::-1]’这样的函数在嵌入时会非常相似。 - 然后,我们计算我们的查询字符串嵌入与数据库中所有数据点之间的余弦相似度。这会给出每个点与我们的查询之间的距离。
- 最后,我们按照它们与我们的查询字符串的距离对所有数据点进行排序,并返回函数参数中请求的结果数量。
from utils.embeddings_utils import cosine_similarity
def search_functions(df, code_query, n=3, pprint=True, n_lines=7):
embedding = get_embedding(code_query, model='text-embedding-3-small')
df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))
res = df.sort_values('similarities', ascending=False).head(n)
if pprint:
for r in res.iterrows():
print(f"{r[1].filepath}:{r[1].function_name} score={round(r[1].similarities, 3)}")
print("\n".join(r[1].code.split("\n")[:n_lines]))
print('-' * 70)
return res
res = search_functions(df, 'fine-tuning input data validation logic', n=3)
openai/validators.py:format_inferrer_validator score=0.453
def format_inferrer_validator(df):
"""
This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
It will also suggest to use ada and explain train/validation split benefits.
"""
ft_type = infer_task_type(df)
immediate_msg = None
----------------------------------------------------------------------
openai/validators.py:infer_task_type score=0.37
def infer_task_type(df):
"""
Infer the likely fine-tuning task type from the data
"""
CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class
if sum(df.prompt.str.len()) == 0:
return "open-ended generation"
----------------------------------------------------------------------
openai/validators.py:apply_validators score=0.369
def apply_validators(
df,
fname,
remediation,
validators,
auto_accept,
write_out_file_func,
----------------------------------------------------------------------
res = search_functions(df, 'find common suffix', n=2, n_lines=10)
openai/validators.py:get_common_xfix score=0.487
def get_common_xfix(series, xfix="suffix"):
"""
Finds the longest common suffix or prefix of all the values in a series
"""
common_xfix = ""
while True:
common_xfixes = (
series.str[-(len(common_xfix) + 1) :]
if xfix == "suffix"
else series.str[: len(common_xfix) + 1]
----------------------------------------------------------------------
openai/validators.py:common_completion_suffix_validator score=0.449
def common_completion_suffix_validator(df):
"""
This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
optional_fn = None
ft_type = infer_task_type(df)
----------------------------------------------------------------------
res = search_functions(df, 'Command line interface for fine-tuning', n=1, n_lines=20)
openai/cli.py:tools_register score=0.391
def tools_register(parser):
subparsers = parser.add_subparsers(
title="Tools", help="Convenience client side tools"
)
def help(args):
parser.print_help()
parser.set_defaults(func=help)
sub = subparsers.add_parser("fine_tunes.prepare_data")
sub.add_argument(
"-f",
"--file",
required=True,
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
"This should be the local file path.",
)
sub.add_argument(
"-q",
----------------------------------------------------------------------