Hyperparameter Optimization¶
The SentenceTransformerTrainer supports hyperparameter optimization using transformers, which in turn supports four hyperparameter search backends: optuna, sigopt, raytune, and wandb. You should install your backend of choice before using it:
pip install optuna/sigopt/wandb/ray[tune]
On this page, we will show you how to perform hyperparameter optimization with the optuna backend. The other backends can be used in a similar way, but you should consult their respective documentation or the transformers HPO documentation for more information.
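For instance, with the raytune backend the search space is expressed with Ray Tune sampling primitives rather than optuna's trial object. The following is a minimal sketch, not taken from the original documentation; the value ranges are only illustrative:

from ray import tune

# Hypothetical search space for the "ray" backend: the function still receives a trial
# argument, but returns Ray Tune sampling objects instead of calling trial.suggest_*.
def ray_hp_space(trial):
    return {
        "num_train_epochs": tune.choice([1, 2]),
        "per_device_train_batch_size": tune.choice([32, 64, 128]),
        "warmup_ratio": tune.uniform(0.0, 0.3),
        "learning_rate": tune.loguniform(1e-6, 1e-4),
    }

# Later: trainer.hyperparameter_search(hp_space=ray_hp_space, backend="ray", ...)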
HPO Components¶
The hyperparameter optimization process consists of the following components:
Hyperparameter Search Space¶
The hyperparameter search space is defined by a function that returns a dictionary of hyperparameters and their respective search spaces. Here is an example of a search space function that defines hyperparameters for a SentenceTransformer model using optuna:
def hpo_search_space(trial):
    return {
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 2),
        "per_device_train_batch_size": trial.suggest_int("per_device_train_batch_size", 32, 128),
        "warmup_ratio": trial.suggest_float("warmup_ratio", 0, 0.3),
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
    }
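If you would rather restrict a hyperparameter to a fixed set of values (for example, power-of-two batch sizes), optuna also provides trial.suggest_categorical. A small variant sketch of the function above (the candidate values are illustrative):

def hpo_search_space(trial):
    return {
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 2),
        # Sample the batch size from a fixed set instead of a continuous integer range
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
        "warmup_ratio": trial.suggest_float("warmup_ratio", 0, 0.3),
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
    }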
Model Init¶
The model init function takes the hyperparameters of the current "trial" as input and returns a SentenceTransformer model. Usually, this function is very simple. Here is an example of a model init function:
def hpo_model_init(trial):
    return SentenceTransformer("distilbert-base-uncased")
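The trial can also be used inside the model init, for example to search over the base model itself. Note that the trainer calls this function once with trial=None when it is constructed, so a fallback is needed. A hedged sketch, where the candidate model names are just examples:

def hpo_model_init(trial):
    if trial is None:
        # The trainer calls model_init without a trial when it is first constructed
        return SentenceTransformer("distilbert-base-uncased")
    # Hypothetical: additionally search over the base model
    model_name = trial.suggest_categorical(
        "model_name", ["distilbert-base-uncased", "microsoft/mpnet-base"]
    )
    return SentenceTransformer(model_name)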
Loss Init¶
The loss init function takes the model initialized for the current trial and returns a loss function. Here is an example of a loss init function:
def hpo_loss_init(model):
    return losses.CosineSimilarityLoss(model)
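Loss arguments that you do not want to search over can simply be fixed here. A small variant sketch (the scale value is illustrative, not a recommendation from the original documentation):

def hpo_loss_init(model):
    # Fix any extra loss arguments here; only the model changes between trials
    return losses.MultipleNegativesRankingLoss(model, scale=20.0)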
Compute Objective¶
The compute objective function takes the evaluation metrics and returns a float value to minimize or maximize. Here is an example of a compute objective function:
def hpo_compute_objective(metrics):
    return metrics["eval_sts-dev_spearman_cosine"]
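The objective does not have to be an evaluator metric. For instance, you could minimize the evaluation loss instead, in which case you would later pass direction="minimize" to hyperparameter_search. A minimal sketch, assuming an eval_dataset is provided so that "eval_loss" is reported:

def hpo_compute_objective(metrics):
    # Minimize the evaluation loss instead of maximizing an evaluator metric
    return metrics["eval_loss"]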
Putting It All Together¶
You can perform HPO in any regular training loop; the only difference is that instead of calling SentenceTransformerTrainer.train, you call SentenceTransformerTrainer.hyperparameter_search. Here is an example of how to put all of the components together:
from sentence_transformers import losses
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from sentence_transformers.training_args import BatchSamplers
from datasets import load_dataset

# 1. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli, using only 10k train and 1k dev samples
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev[:1000]")

# 2. Create an evaluator to perform useful HPO
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=stsb_eval_dataset["sentence1"],
    sentences2=stsb_eval_dataset["sentence2"],
    scores=stsb_eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
# 3. Define the hyperparameter search space
def hpo_search_space(trial):
    return {
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 2),
        "per_device_train_batch_size": trial.suggest_int("per_device_train_batch_size", 32, 128),
        "warmup_ratio": trial.suggest_float("warmup_ratio", 0, 0.3),
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
    }

# 4. Define the model init
def hpo_model_init(trial):
    return SentenceTransformer("distilbert-base-uncased")

# 5. Define the loss init
def hpo_loss_init(model):
    return losses.MultipleNegativesRankingLoss(model)

# 6. Define the objective function
def hpo_compute_objective(metrics):
    """
    Valid keys are: 'eval_loss', 'eval_sts-dev_pearson_cosine', 'eval_sts-dev_spearman_cosine',
    'eval_sts-dev_pearson_manhattan', 'eval_sts-dev_spearman_manhattan', 'eval_sts-dev_pearson_euclidean',
    'eval_sts-dev_spearman_euclidean', 'eval_sts-dev_pearson_dot', 'eval_sts-dev_spearman_dot',
    'eval_sts-dev_pearson_max', 'eval_sts-dev_spearman_max', 'eval_runtime', 'eval_samples_per_second',
    'eval_steps_per_second', 'epoch', due to the evaluator that we are using.
    """
    return metrics["eval_sts-dev_spearman_cosine"]
# 7. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="checkpoints",
    # Optional training parameters:
    # max_steps=10000, # We might want to limit the number of steps for HPO
    fp16=True,  # Set to False if your GPU can't handle FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="no",  # We don't need to evaluate/save during HPO
    save_strategy="no",
    logging_steps=10,
    run_name="hpo",  # Will be used in W&B if `wandb` is installed
)
# 8. Create the trainer with model_init rather than model
trainer = SentenceTransformerTrainer(
    model=None,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
    model_init=hpo_model_init,
    loss=hpo_loss_init,
)

# 9. Perform the HPO
best_trial = trainer.hyperparameter_search(
    hp_space=hpo_search_space,
    compute_objective=hpo_compute_objective,
    n_trials=20,
    direction="maximize",
    backend="optuna",
)
print(best_trial)
[I 2024-05-17 15:10:47,844] Trial 0 finished with value: 0.7889856589698055 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 123, 'warmup_ratio': 0.07380948785410107, 'learning_rate': 2.686331417509812e-06}. Best is trial 0 with value: 0.7889856589698055.
[I 2024-05-17 15:12:13,283] Trial 1 finished with value: 0.7927780672090986 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 69, 'warmup_ratio': 0.2927897848007451, 'learning_rate': 5.885372118095137e-06}. Best is trial 1 with value: 0.7927780672090986.
[I 2024-05-17 15:12:43,896] Trial 2 finished with value: 0.7684829743509601 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 114, 'warmup_ratio': 0.0739429232666916, 'learning_rate': 7.344415188959276e-05}. Best is trial 1 with value: 0.7927780672090986.
[I 2024-05-17 15:14:49,730] Trial 3 finished with value: 0.7873032743147989 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 43, 'warmup_ratio': 0.15184370143796674, 'learning_rate': 9.703232080395476e-06}. Best is trial 1 with value: 0.7927780672090986.
[I 2024-05-17 15:15:39,597] Trial 4 finished with value: 0.7759251781929949 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 127, 'warmup_ratio': 0.263946220093495, 'learning_rate': 1.231454337152625e-06}. Best is trial 1 with value: 0.7927780672090986.
[I 2024-05-17 15:17:02,191] Trial 5 finished with value: 0.7964580509886684 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 34, 'warmup_ratio': 0.2276865359631089, 'learning_rate': 7.889007438884571e-06}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:18:55,559] Trial 6 finished with value: 0.7901878917859169 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 48, 'warmup_ratio': 0.23228838664572948, 'learning_rate': 2.883013292682523e-06}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:20:27,027] Trial 7 finished with value: 0.7935671067660925 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 62, 'warmup_ratio': 0.22061123927198237, 'learning_rate': 2.95413457610349e-06}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:22:23,147] Trial 8 finished with value: 0.7848123114933252 and parameters: {'num_train_epochs': 2, 'per_device_train_batch_size': 45, 'warmup_ratio': 0.23071701022961139, 'learning_rate': 9.793681667449783e-06}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:22:52,826] Trial 9 finished with value: 0.7909708416168918 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 121, 'warmup_ratio': 0.22440506724181647, 'learning_rate': 4.0744671365843346e-05}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:23:30,395] Trial 10 finished with value: 0.7928991732385567 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 89, 'warmup_ratio': 0.14607293301068847, 'learning_rate': 2.5557492055039498e-05}. Best is trial 5 with value: 0.7964580509886684.
[I 2024-05-17 15:24:18,024] Trial 11 finished with value: 0.7991870087507459 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 66, 'warmup_ratio': 0.16886154348739527, 'learning_rate': 3.705926066938032e-06}. Best is trial 11 with value: 0.7991870087507459.
[I 2024-05-17 15:25:44,198] Trial 12 finished with value: 0.7923304174306207 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 33, 'warmup_ratio': 0.15953772535423974, 'learning_rate': 1.8076298025704224e-05}. Best is trial 11 with value: 0.7991870087507459.
[I 2024-05-17 15:26:20,739] Trial 13 finished with value: 0.8020260244040395 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 90, 'warmup_ratio': 0.18105202625281253, 'learning_rate': 5.513908793512551e-06}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:26:57,783] Trial 14 finished with value: 0.7571110256860063 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 95, 'warmup_ratio': 0.00122391151793258, 'learning_rate': 1.0432486633629492e-06}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:27:32,581] Trial 15 finished with value: 0.8009013936824717 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 101, 'warmup_ratio': 0.1761274711346081, 'learning_rate': 4.5918293464430035e-06}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:28:05,850] Trial 16 finished with value: 0.8017668050806169 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 103, 'warmup_ratio': 0.10766501647726355, 'learning_rate': 5.0309795522333e-06}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:28:37,393] Trial 17 finished with value: 0.7769412380909586 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 108, 'warmup_ratio': 0.1036610178950246, 'learning_rate': 1.7747598626081271e-06}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:29:19,340] Trial 18 finished with value: 0.8011921300048339 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 80, 'warmup_ratio': 0.117014165550441, 'learning_rate': 1.238558867958792e-05}. Best is trial 13 with value: 0.8020260244040395.
[I 2024-05-17 15:29:59,508] Trial 19 finished with value: 0.8027501854704168 and parameters: {'num_train_epochs': 1, 'per_device_train_batch_size': 84, 'warmup_ratio': 0.014601112207929548, 'learning_rate': 5.627813947769514e-06}. Best is trial 19 with value: 0.8027501854704168.
BestRun(run_id='19', objective=0.8027501854704168, hyperparameters={'num_train_epochs': 1, 'per_device_train_batch_size': 84, 'warmup_ratio': 0.014601112207929548, 'learning_rate': 5.627813947769514e-06}, run_summary=None)
As you can see, the strongest hyperparameters reached a Spearman correlation of 0.802 on the STS (dev) benchmark. For context, training with the default training arguments (per_device_train_batch_size=8, learning_rate=5e-5) results in 0.736, and training with heuristically chosen hyperparameters (per_device_train_batch_size=64, learning_rate=2e-5) results in a Spearman correlation of 0.783. Consequently, HPO proved quite effective here, noticeably improving the model performance.
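To actually use the winning configuration, you would typically run a regular training afterwards with the best hyperparameters. The following is a minimal sketch reusing the objects from the example above; the output paths and the re-enabled evaluation/saving strategies are assumptions, not part of the original script:

# Reuse the best hyperparameters found above for a final, regular training run
model = hpo_model_init(None)  # hpo_model_init above ignores the trial, so None works here
final_args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints-final",
    fp16=True,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    eval_strategy="steps",  # Evaluation and saving are re-enabled for the final run
    save_strategy="steps",
    **best_trial.hyperparameters,  # num_train_epochs, per_device_train_batch_size, warmup_ratio, learning_rate
)
final_trainer = SentenceTransformerTrainer(
    model=model,
    args=final_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    evaluator=dev_evaluator,
    loss=hpo_loss_init(model),
)
final_trainer.train()
model.save_pretrained("checkpoints-final/final")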
Example Scripts¶
hpo_nli.py - An example script that performs hyperparameter optimization on the AllNLI dataset.