
Fine-tuning ChatGPT-3.5 and GPT-4 with Weights & Biases


Note: you will need an OpenAI API key to run this Colab.

If you use OpenAI's API to fine-tune ChatGPT-3.5, you can now use the W&B integration to track experiments, models, and datasets in your central dashboard.

All it takes is one line: openai wandb sync

See the OpenAI section in the Weights & Biases documentation for full details of the integration.

!pip install -Uq openai tiktoken datasets tenacity wandb

# Remove after this PR is merged: https://github.com/openai/openai-python/pull/590 and OpenAI releases a new version.
!pip uninstall -y openai -qq \
&& pip install git+https://github.com/morganmcg1/openai-python.git@update_wandb_logger -qqq

Optional: Fine-tune ChatGPT-3.5

It's always more fun to experiment with your own projects, so if you have already used the openai API to fine-tune an OpenAI model, just skip this section.

Otherwise, let's fine-tune ChatGPT-3.5 on a legal dataset!

Imports and initial setup

import openai
import wandb

import os
import json
import random
import tiktoken
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from collections import defaultdict
from tenacity import retry, stop_after_attempt, wait_fixed

Start your Weights & Biases run. If you don't have an account yet, you can sign up for free at www.wandb.ai.
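If this environment isn't authenticated yet, a minimal login sketch (assuming you already have a W&B API key) looks like this:

# Optional: log in to W&B for this session; you'll be prompted for your API key if needed
wandb.login()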

WANDB_PROJECT = "OpenAI-Fine-Tune"

Set up your API key

# Enter credentials
openai_key = "YOUR_API_KEY"

openai.api_key = openai_key

Dataset preparation

We download a dataset from LegalBench, a project to curate tasks for evaluating legal reasoning, specifically the Contract NLI Explicit Identification task.

This dataset contains a total of 117 examples, from which we will create our own training and test sets.

from datasets import load_dataset

# Download the data, merge into a single dataset and shuffle
dataset = load_dataset("nguha/legalbench", "contract_nli_explicit_identification")

data = []
for d in dataset["train"]:
    data.append(d)

for d in dataset["test"]:
    data.append(d)

random.shuffle(data)

for idx, d in enumerate(data):
    d["new_index"] = idx

Let's look at a few examples.

len(data), data[0:2]

(117,
[{'answer': 'No',
'index': '94',
'text': 'Recipient shall use the Confidential Information exclusively for HySafe purposes, especially to advice the Governing Board of HySafe. ',
'document_name': 'NDA_V3.pdf',
'new_index': 0},
{'answer': 'No',
'index': '53',
'text': '3. In consideration of each and every disclosure of CONFIDENTIAL INFORMATION, the Parties agree to: (c) make no disclosures of any CONFIDENTIAL INFORMATION to any party other than officers and employees of a Party to this IRA; (d) limit access to CONFIDENTIAL INFORMATION to those officers and employees having a reasonable need for such INFORMATION and being boUnd by a written obligation to maintain the confidentiality of such INFORMATION; and ',
'document_name': '1084000_0001144204-06-046785_v056501_ex10-16.txt',
'new_index': 1}])

Format our data for chat completion models

We modify the base_prompt from the LegalBench task to be a zero-shot prompt, as we are training the model instead of using few-shot prompting.

base_prompt_zero_shot = "Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`"

We now split the data into training and validation sets, with 30 samples for training and the rest for testing.

n_train = 30
n_test = len(data) - n_train

train_messages = []
test_messages = []

for d in data:
    prompts = []
    prompts.append({"role": "system", "content": base_prompt_zero_shot})
    prompts.append({"role": "user", "content": d["text"]})
    prompts.append({"role": "assistant", "content": d["answer"]})

    if int(d["new_index"]) < n_train:
        train_messages.append({'messages': prompts})
    else:
        test_messages.append({'messages': prompts})

len(train_messages), len(test_messages), n_test, train_messages[5]

(30,
87,
87,
{'messages': [{'role': 'system',
'content': 'Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`'},
{'role': 'user',
'content': '2. The Contractor shall not, without the State’s prior written consent, copy, disclose, publish, release, transfer, disseminate, use, or allow access for any purpose or in any form, any Confidential Information except for the sole and exclusive purpose of performing under the Contract. '},
{'role': 'assistant', 'content': 'No'}]})

Save the data to Weights & Biases

First, save the data to a training file and a test file.

train_file_path = 'encoded_train_data.jsonl'
with open(train_file_path, 'w') as file:
    for item in train_messages:
        line = json.dumps(item)
        file.write(line + '\n')

test_file_path = 'encoded_test_data.jsonl'
with open(test_file_path, 'w') as file:
    for item in test_messages:
        line = json.dumps(item)
        file.write(line + '\n')

Next, we validate that our training data is in the correct format, using a script from the OpenAI fine-tuning documentation.

# Next, we specify the data path and open the JSONL file

def openai_validate_data(dataset_path):
    data_path = dataset_path

    # Load dataset
    with open(data_path) as f:
        dataset = [json.loads(line) for line in f]

    # We can inspect the data quickly by checking the number of examples and the first item

    # Initial dataset stats
    print("Num examples:", len(dataset))
    print("First example:")
    for message in dataset[0]["messages"]:
        print(message)

    # Now that we have a sense of the data, we need to go through all of our different examples
    # and check that the formatting is correct and matches the chat completions message structure

    # Format error checks
    format_errors = defaultdict(int)

    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        if not any(message.get("role", None) == "assistant" for message in messages):
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found")

    # Beyond the structure of the messages, we also need to ensure that the length does not exceed the 4096 token limit.

    # Token counting functions
    encoding = tiktoken.get_encoding("cl100k_base")

    # not exact!
    # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    def num_assistant_tokens_from_messages(messages):
        num_tokens = 0
        for message in messages:
            if message["role"] == "assistant":
                num_tokens += len(encoding.encode(message["content"]))
        return num_tokens

    def print_distribution(values, name):
        print(f"\n#### Distribution of {name}:")
        print(f"min / max: {min(values)}, {max(values)}")
        print(f"mean / median: {np.mean(values)}, {np.median(values)}")
        print(f"p5 / p95: {np.quantile(values, 0.1)}, {np.quantile(values, 0.9)}")

    # Last, we can look at the results of the different formatting operations before proceeding with creating a fine-tuning job:

    # Warnings and token counts
    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    for ex in dataset:
        messages = ex["messages"]
        if not any(message["role"] == "system" for message in messages):
            n_missing_system += 1
        if not any(message["role"] == "user" for message in messages):
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

    print("Num examples missing system message:", n_missing_system)
    print("Num examples missing user message:", n_missing_user)
    print_distribution(n_messages, "num_messages_per_example")
    print_distribution(convo_lens, "num_total_tokens_per_example")
    print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(l > 4096 for l in convo_lens)
    print(f"\n{n_too_long} examples may be over the 4096 token limit, they will be truncated during fine-tuning")

    # Pricing and default n_epochs estimate
    MAX_TOKENS_PER_EXAMPLE = 4096

    MIN_TARGET_EXAMPLES = 100
    MAX_TARGET_EXAMPLES = 25000
    TARGET_EPOCHS = 3
    MIN_EPOCHS = 1
    MAX_EPOCHS = 25

    n_epochs = TARGET_EPOCHS
    n_train_examples = len(dataset)
    if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
        n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
    elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
        n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

    n_billing_tokens_in_dataset = sum(min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens)
    print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
    print(f"By default, you'll train for {n_epochs} epochs on this dataset")
    print(f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens")
    print("See pricing page to estimate total costs")

Validate the training data

openai_validate_data(train_file_path)

Num examples: 30
First example:
{'role': 'system', 'content': 'Identify if the clause provides that all Confidential Information shall be expressly identified by the Disclosing Party. Answer with only `Yes` or `No`'}
{'role': 'user', 'content': 'Recipient shall use the Confidential Information exclusively for HySafe purposes, especially to advice the Governing Board of HySafe. '}
{'role': 'assistant', 'content': 'No'}
No errors found
Num examples missing system message: 0
Num examples missing user message: 0

#### Distribution of num_messages_per_example:
min / max: 3, 3
mean / median: 3.0, 3.0
p5 / p95: 3.0, 3.0

#### Distribution of num_total_tokens_per_example:
min / max: 69, 319
mean / median: 143.46666666666667, 122.0
p5 / p95: 82.10000000000001, 235.10000000000002

#### Distribution of num_assistant_tokens_per_example:
min / max: 1, 1
mean / median: 1.0, 1.0
p5 / p95: 1.0, 1.0

0 examples may be over the 4096 token limit, they will be truncated during fine-tuning
Dataset has ~4304 tokens that will be charged for during training
By default, you'll train for 3 epochs on this dataset
By default, you'll be charged for ~12912 tokens
See pricing page to estimate total costs

Log our data to Weights & Biases Artifacts for storage and versioning.

wandb.init(
    project=WANDB_PROJECT,
    # entity="prompt-eng",
    job_type="log-data",
    config={'n_train': n_train,
            'n_valid': n_test})

wandb.log_artifact(train_file_path,
                   "legalbench-contract_nli_explicit_identification-train",
                   type="train-data")

wandb.log_artifact(test_file_path,
                   "legalbench-contract_nli_explicit_identification-test",
                   type="test-data")

# keep the entity (typically your wandb username) to reference the artifacts later in this demo
entity = wandb.run.entity

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: capecape. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.15.9
Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_113853-ivu21mjl
Waiting for W&B process to finish... (success).
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job
View run mild-surf-1 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ivu21mjl
Synced 6 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
Find logs at: ./wandb/run-20230830_113853-ivu21mjl/logs

Create a fine-tuned model

We'll now use the OpenAI API to fine-tune ChatGPT-3.5.

Let's first download our training and validation files and save them to a folder called my_data. We will retrieve the latest version of the artifact, but it could also be v0, v1, or any alias we associated with it.

wandb.init(project=WANDB_PROJECT,
           # entity="prompt-eng",
           job_type="finetune")

artifact_train = wandb.use_artifact(
    f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-train:latest',
    type='train-data')
train_file = artifact_train.get_path(train_file_path).download("my_data")

train_file

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016751802766035932, max=1.0…
Tracking run with wandb version 0.15.9
Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_113907-1ili9l51
'my_data/encoded_train_data.jsonl'

Then we upload the training data to OpenAI. OpenAI has to process the data, so this will take a few minutes depending on the size of your dataset.

openai_train_file_info = openai.File.create(
    file=open(train_file, "rb"),
    purpose='fine-tune'
)

# you may need to wait a couple of minutes for OpenAI to process the file
openai_train_file_info

<File file id=file-spPASR6VWco54SqfN2yo7T8v> JSON: {
"object": "file",
"id": "file-spPASR6VWco54SqfN2yo7T8v",
"purpose": "fine-tune",
"filename": "file",
"bytes": 24059,
"created_at": 1693388388,
"status": "uploaded",
"status_details": null
}

Time to train the model!

Let's define our ChatGPT-3.5 fine-tuning hyperparameters.

model = 'gpt-3.5-turbo'
n_epochs = 3

openai_ft_job_info = openai.FineTuningJob.create(
    training_file=openai_train_file_info["id"],
    model=model,
    hyperparameters={"n_epochs": n_epochs}
)

ft_job_id = openai_ft_job_info["id"]

openai_ft_job_info

<FineTuningJob fine_tuning.job id=ftjob-x4tl83IlSGolkUF3fCFyZNGs> JSON: {
"object": "fine_tuning.job",
"id": "ftjob-x4tl83IlSGolkUF3fCFyZNGs",
"model": "gpt-3.5-turbo-0613",
"created_at": 1693388447,
"finished_at": null,
"fine_tuned_model": null,
"organization_id": "org-WnF2wEqNkV1Nj65CzDxr6iUm",
"result_files": [],
"status": "created",
"validation_file": null,
"training_file": "file-spPASR6VWco54SqfN2yo7T8v",
"hyperparameters": {
"n_epochs": 3
},
"trained_tokens": null
}

This takes around 5 minutes to train, and you'll get an email from OpenAI when the training is finished.

That's it!

Your model is now training on OpenAI's machines. To get the current state of your fine-tuning job, run:

state = openai.FineTuningJob.retrieve(ft_job_id)
state["status"], state["trained_tokens"], state["finished_at"], state["fine_tuned_model"]

('succeeded',
12732,
1693389024,
'ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX')
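Rather than re-running the cell by hand, you could also poll the job until it reaches a terminal state. Below is a minimal sketch using the same openai.FineTuningJob API as above; wait_for_finetune is a hypothetical helper, not part of the original notebook.

import time

def wait_for_finetune(job_id, poll_seconds=60):
    # Poll the fine-tuning job until it succeeds, fails, or is cancelled
    while True:
        job = openai.FineTuningJob.retrieve(job_id)
        print("status:", job["status"])
        if job["status"] in ("succeeded", "failed", "cancelled"):
            return job
        time.sleep(poll_seconds)

# state = wait_for_finetune(ft_job_id)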

Show recent events for our fine-tuning job.

openai.FineTuningJob.list_events(id=ft_job_id, limit=5)

<OpenAIObject list> JSON: {
"object": "list",
"data": [
{
"object": "fine_tuning.job.event",
"id": "ftevent-5x9Y6Payk6fIdyJyMRY5um1v",
"created_at": 1693389024,
"level": "info",
"message": "Fine-tuning job successfully completed",
"data": null,
"type": "message"
},
{
"object": "fine_tuning.job.event",
"id": "ftevent-i16NTGNakv9P0RkOtJ7vvvoG",
"created_at": 1693389022,
"level": "info",
"message": "New fine-tuned model created: ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX",
"data": null,
"type": "message"
},
{
"object": "fine_tuning.job.event",
"id": "ftevent-MkLrJQ8sDgaC67CdmFMwsIjV",
"created_at": 1693389017,
"level": "info",
"message": "Step 90/90: training loss=0.00",
"data": {
"step": 90,
"train_loss": 2.5828578600339824e-06,
"train_mean_token_accuracy": 1.0
},
"type": "metrics"
},
{
"object": "fine_tuning.job.event",
"id": "ftevent-3sRpTRSjK3TfFRZY88HEASpX",
"created_at": 1693389015,
"level": "info",
"message": "Step 89/90: training loss=0.00",
"data": {
"step": 89,
"train_loss": 2.5828578600339824e-06,
"train_mean_token_accuracy": 1.0
},
"type": "metrics"
},
{
"object": "fine_tuning.job.event",
"id": "ftevent-HtS6tJMVPOmazquZ82a1iCdV",
"created_at": 1693389015,
"level": "info",
"message": "Step 88/90: training loss=0.00",
"data": {
"step": 88,
"train_loss": 2.5828578600339824e-06,
"train_mean_token_accuracy": 1.0
},
"type": "metrics"
}
],
"has_more": true
}

We could run a few different fine-tunes with different parameters or with different datasets.
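For example, a second job could reuse the training file we already uploaded and only change the epoch count. This is a hedged sketch: the variable names and n_epochs=5 are illustrative, not a recommendation from the original notebook.

# Hypothetical second fine-tune with a different hyperparameter setting
second_ft_job_info = openai.FineTuningJob.create(
    training_file=openai_train_file_info["id"],
    model="gpt-3.5-turbo",
    hyperparameters={"n_epochs": 5},
)
second_ft_job_id = second_ft_job_info["id"]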

Log OpenAI fine-tune jobs to Weights & Biases

We can log our fine-tunes with a simple command.

!openai wandb sync --help

usage: openai wandb sync [-h] [-i ID] [-n N_FINE_TUNES] [--project PROJECT]
[--entity ENTITY] [--force] [--legacy]

options:
-h, --help show this help message and exit
-i ID, --id ID The id of the fine-tune job (optional)
-n N_FINE_TUNES, --n_fine_tunes N_FINE_TUNES
Number of most recent fine-tunes to log when an id is
not provided. By default, every fine-tune is synced.
--project PROJECT Name of the Weights & Biases project where you're
sending runs. By default, it is "OpenAI-Fine-Tune".
--entity ENTITY Weights & Biases username or team name where you're
sending runs. By default, your default entity is used,
which is usually your username.
--force Forces logging and overwrite existing wandb run of the
same fine-tune.
--legacy Log results from legacy OpenAI /v1/fine-tunes api

Calling openai wandb sync will log all un-synced fine-tune jobs to W&B.

Below we log a single job, passing in:
- our OpenAI key as an environment variable
- the id of the fine-tune job we'd like to log
- the W&B project to log it to

See the OpenAI section in the Weights & Biases documentation for full details of the integration.

!OPENAI_API_KEY={openai_key} openai wandb sync --id {ft_job_id} --project {WANDB_PROJECT}

Retrieving fine-tune job...
wandb: Currently logged in as: capecape. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.9
wandb: Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_115915-ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: ⭐️ View project at https://wandb.ai/capecape/OpenAI-Fine-Tune
wandb: 🚀 View run at https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Waiting for W&B process to finish... (success).
wandb:
wandb: Run history:
wandb: train_accuracy ▁▁▁▁▁█▁█▁██▁████████████████████████████
wandb: train_loss █▇▆▂▂▁▂▁▅▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:
wandb: Run summary:
wandb: fine_tuned_model ft:gpt-3.5-turbo-061...
wandb: status succeeded
wandb: train_accuracy 1.0
wandb: train_loss 0.0
wandb:
wandb: 🚀 View run ftjob-x4tl83IlSGolkUF3fCFyZNGs at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/ftjob-x4tl83IlSGolkUF3fCFyZNGs
wandb: Synced 6 W&B file(s), 0 media file(s), 1 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20230830_115915-ftjob-x4tl83IlSGolkUF3fCFyZNGs/logs
🎉 wandb sync completed successfully

wandb.finish()

Waiting for W&B process to finish... (success).
VBox(children=(Label(value='0.050 MB of 0.050 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job
upload_file exception https://storage.googleapis.com/wandb-production.appspot.com/capecape/OpenAI-Fine-Tune/1ili9l51/requirements.txt?Expires=1693475972&GoogleAccessId=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com&Signature=NzF9wj2gS8rMEwRT9wlft2lNubemw67f2qrz9Zy90Bjxg5xCL9pIu%2FRbBGjRwLA2v64PuiP23Au5Dho26Tnw3UjUS1apqTkaOgjWDTlCCiDLzvMUsqHf0lhhWIgGMZcsA4gPpOi%2Bc%2ByJm4z6JE7D6RJ7r8y4fI0Jg6fX9KSWpzh8INiM6fQZiQjUChLVdtNJQZ2gfu7xRc%2BZIUEjgDuUqmS705pIUOgJXA9MS3%2Fhewkc7CxWay4ReMJixBZgaqLIRqHQnyzb38I5nPrRS3JrwrigQyX6tOsK05LDLA0o%2Bs0K11664%2F1ZxO6mSTfOaw7tXUmbUUWFOp33Qq8KXNz9Zg%3D%3D: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
upload_file request headers: {'User-Agent': 'python-requests/2.28.2', 'Accept-Encoding': 'gzip, deflate, br', 'Accept': '*/*', 'Connection': 'keep-alive', 'Content-Length': '4902'}
upload_file response body:
upload_file exception https://storage.googleapis.com/wandb-production.appspot.com/capecape/OpenAI-Fine-Tune/1ili9l51/conda-environment.yaml?Expires=1693475972&GoogleAccessId=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com&Signature=wKnFdg7z7CiJOMn4WSvt6GSj2hPnMr0Xc4KuwAXa8akLucmw700x%2FWF87jmWaqnp%2FK4%2BF6JTRghQAokXF9jxCcXBSYhgFhCVACrOVyN%2BSTZ4u8tDgD6Dm%2FEFwWObiH%2BALSS1N0FmG7i6kL9Evyng3yPc4noEz%2FkLNIDIascAPgUe9UkPaBCRc9j7OxzYJx07bpeL4HaGe4yaCvk2mSVr4l%2FUfsICBI6E4KKrLDvtZvFFFUB4MgqXp0Sxc0k0pOxaw9zZhiNQQELDnhnuNY4wi78EPiXN1BpU6bTgIYaHe5mkS%2B7M5HiFs83ML98JI2OeRiAjAGtIIETT4xDjTYWVpA%3D%3D: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
upload_file request headers: {'User-Agent': 'python-requests/2.28.2', 'Accept-Encoding': 'gzip, deflate, br', 'Accept': '*/*', 'Connection': 'keep-alive', 'Content-Length': '8450'}
upload_file response body:
View run jumping-water-2 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/1ili9l51
Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
Find logs at: ./wandb/run-20230830_113907-1ili9l51/logs

Our fine-tunes are now successfully synced to Weights & Biases.


Anytime we have new fine-tunes, we can just call openai wandb sync to add them to our dashboard.
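For instance, a follow-up sync cell (a sketch re-using the variables defined above) that logs every un-synced job to the same project, per the --help output shown earlier, would look like:

# Sync all un-synced fine-tune jobs to the same W&B project (omit --id to sync everything)
!OPENAI_API_KEY={openai_key} openai wandb sync --project {WANDB_PROJECT}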

Run evaluation and log the results

The best way to evaluate a generative model is to explore sample predictions from your evaluation set.

Let's generate a few inference samples, log them to W&B, and see how the performance compares to a baseline ChatGPT-3.5 model.

wandb.init(project=WANDB_PROJECT,
           job_type='eval')

artifact_valid = wandb.use_artifact(
    f'{entity}/{WANDB_PROJECT}/legalbench-contract_nli_explicit_identification-test:latest',
    type='test-data')
test_file = artifact_valid.get_path(test_file_path).download("my_data")

with open(test_file) as f:
    test_dataset = [json.loads(line) for line in f]

print(f"There are {len(test_dataset)} test examples")
wandb.config.update({"num_test_samples": len(test_dataset)})

Tracking run with wandb version 0.15.9
Run data is saved locally in /Users/tcapelle/work/examples/colabs/openai/wandb/run-20230830_115947-iepk19m2
There are 87 test examples

Run evaluation on the fine-tuned model

Set up an OpenAI call with retries.

@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def call_openai(messages="", model="gpt-3.5-turbo"):
    return openai.ChatCompletion.create(model=model, messages=messages, max_tokens=10)

Let's get our fine-tuned model id.

state = openai.FineTuningJob.retrieve(ft_job_id)
ft_model_id = state["fine_tuned_model"]
ft_model_id

'ft:gpt-3.5-turbo-0613:weights-biases::7tC85HcX'

Run evaluation and log the results to W&B

prediction_table = wandb.Table(columns=['messages', 'completion', 'target'])

eval_data = []

for row in tqdm(test_dataset):
    messages = row['messages'][:2]
    target = row["messages"][2]

    # res = call_openai(messages=messages, model=ft_model_id)
    res = openai.ChatCompletion.create(model=ft_model_id, messages=messages, max_tokens=10)
    completion = res.choices[0].message.content

    eval_data.append([messages, completion, target])
    prediction_table.add_data(messages[1]['content'], completion, target["content"])

wandb.log({'predictions': prediction_table})

  0%|          | 0/87 [00:00<?, ?it/s]

Calculate the accuracy of the fine-tuned model and log it to W&B.

correct = 0
for e in eval_data:
    if e[1].lower() == e[2]["content"].lower():
        correct += 1

accuracy = correct / len(eval_data)

print(f"Accuracy is {accuracy}")
wandb.log({"eval/accuracy": accuracy})
wandb.summary["eval/accuracy"] = accuracy

Accuracy is 0.8390804597701149

Run evaluation on a baseline model for comparison

Let's compare our model to the baseline model, gpt-3.5-turbo.

baseline_prediction_table = wandb.Table(columns=['messages', 'completion', 'target'])
baseline_eval_data = []

for row in tqdm(test_dataset):
    messages = row['messages'][:2]
    target = row["messages"][2]

    res = call_openai(model="gpt-3.5-turbo", messages=messages)
    completion = res.choices[0].message.content

    baseline_eval_data.append([messages, completion, target])
    baseline_prediction_table.add_data(messages[1]['content'], completion, target["content"])

wandb.log({'baseline_predictions': baseline_prediction_table})

  0%|          | 0/87 [00:00<?, ?it/s]

Calculate the accuracy of the baseline model and log it to W&B.

baseline_correct = 0
for e in baseline_eval_data:
    if e[1].lower() == e[2]["content"].lower():
        baseline_correct += 1

baseline_accuracy = baseline_correct / len(baseline_eval_data)
print(f"Baseline Accuracy is: {baseline_accuracy}")
wandb.log({"eval/baseline_accuracy": baseline_accuracy})
wandb.summary["eval/baseline_accuracy"] = baseline_accuracy

Baseline Accuracy is: 0.7931034482758621

wandb.finish()

Waiting for W&B process to finish... (success).
VBox(children=(Label(value='0.248 MB of 0.248 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

Run history:


eval/accuracy
eval/baseline_accuracy

Run summary:


eval/accuracy 0.83908
eval/baseline_accuracy 0.7931

View run ethereal-energy-4 at: https://wandb.ai/capecape/OpenAI-Fine-Tune/runs/iepk19m2
Synced 7 W&B file(s), 2 media file(s), 2 artifact file(s) and 1 other file(s)
Find logs at: ./wandb/run-20230830_115947-iepk19m2/logs

And that's it! In this example we prepared our data, logged it to Weights & Biases, fine-tuned an OpenAI model with that data, logged the results to Weights & Biases, and then ran evaluation on the fine-tuned model.

From here you can start to train on larger or more complex tasks, or explore other ways to modify ChatGPT-3.5, such as giving it a different tone and style of responding.
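As one hedged illustration (not part of this dataset), a tone-and-style training example would simply pair a styled system prompt with assistant replies written in the target voice:

# Hypothetical example: teaching a consistent, plain-English tone via the fine-tuning data
style_example = {
    "messages": [
        {"role": "system", "content": "You are a concise legal assistant who answers in plain English."},
        {"role": "user", "content": "Does this clause require written consent before disclosure?"},
        {"role": "assistant", "content": "Yes. Disclosure is only allowed with the other party's prior written consent."},
    ]
}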

Resources