Using Huggingface Transformers with Tune#

Example#
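
This example fine-tunes a BERT model on the GLUE RTE task and tunes its hyperparameters with the Population Based Training scheduler, driven through the `Trainer.hyperparameter_search` API with the Ray backend.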

"""
This example uses the `hyperparameter_search` API from the official huggingface transformers library.
"""
import os

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.examples.pbt_transformers.utils import (
    download_data,
    build_compute_metrics_fn,
)
from ray.tune.schedulers import PopulationBasedTraining
from transformers import (
    glue_tasks_num_labels,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    GlueDataset,
    GlueDataTrainingArguments,
    TrainingArguments,
)


def tune_transformer(num_samples=8, gpus_per_trial=0, smoke_test=False):
    data_dir_name = "./data" if not smoke_test else "./test_data"
    data_dir = os.path.abspath(os.path.join(os.getcwd(), data_dir_name))
    if not os.path.exists(data_dir):
        os.mkdir(data_dir, 0o755)

    # Change these as needed.
    model_name = (
        "bert-base-uncased" if not smoke_test else "sshleifer/tiny-distilroberta-base"
    )
    task_name = "rte"

    task_data_dir = os.path.join(data_dir, task_name.upper())

    num_labels = glue_tasks_num_labels[task_name]

    config = AutoConfig.from_pretrained(
        model_name, num_labels=num_labels, finetuning_task=task_name
    )

    # Download and cache tokenizer, model, and features
    print("Downloading and caching Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Triggers model download to cache
    print("Downloading and caching pre-trained model")
    AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
    )

    def get_model():
        return AutoModelForSequenceClassification.from_pretrained(
            model_name,
            config=config,
        )

    # Download data.
    download_data(task_name, data_dir)

    data_args = GlueDataTrainingArguments(task_name=task_name, data_dir=task_data_dir)

    train_dataset = GlueDataset(
        data_args, tokenizer=tokenizer, mode="train", cache_dir=task_data_dir
    )
    eval_dataset = GlueDataset(
        data_args, tokenizer=tokenizer, mode="dev", cache_dir=task_data_dir
    )

    training_args = TrainingArguments(
        output_dir=".",
        learning_rate=1e-5,  # config: overridden by Tune
        do_train=True,
        do_eval=True,
        no_cuda=gpus_per_trial <= 0,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        num_train_epochs=2,  # config: overridden by Tune
        max_steps=-1,
        per_device_train_batch_size=16,  # config: overridden by Tune
        per_device_eval_batch_size=16,  # config: overridden by Tune
        warmup_steps=0,
        weight_decay=0.1,  # config: overridden by Tune
        logging_dir="./logs",
        skip_memory_metrics=True,
        report_to="none",
    )

    trainer = Trainer(
        model_init=get_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(task_name),
    )

    tune_config = {
        "per_device_train_batch_size": 32,
        "per_device_eval_batch_size": 32,
        "num_train_epochs": tune.choice([2, 3, 4, 5]),
        "max_steps": 1 if smoke_test else -1,  # 用于冒烟测试。
    }

    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="eval_acc",
        mode="max",
        perturbation_interval=1,
        hyperparam_mutations={
            "weight_decay": tune.uniform(0.0, 0.3),
            "learning_rate": tune.uniform(1e-5, 5e-5),
            "per_device_train_batch_size": [16, 32, 64],
        },
    )

    reporter = CLIReporter(
        parameter_columns={
            "weight_decay": "w_decay",
            "learning_rate": "lr",
            "per_device_train_batch_size": "train_bs/gpu",
            "num_train_epochs": "num_epochs",
        },
        metric_columns=["eval_acc", "eval_loss", "epoch", "training_iteration"],
    )

    trainer.hyperparameter_search(
        hp_space=lambda _: tune_config,
        backend="ray",
        n_trials=num_samples,
        resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
        scheduler=scheduler,
        keep_checkpoints_num=1,
        checkpoint_score_attr="training_iteration",
        stop={"training_iteration": 1} if smoke_test else None,
        progress_reporter=reporter,
        local_dir="~/ray_results/",
        name="tune_transformer_pbt",
        log_to_file=True,
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    args, _ = parser.parse_known_args()

    ray.init()

    if args.smoke_test:
        tune_transformer(num_samples=1, gpus_per_trial=0, smoke_test=True)
    else:
        # You can change the number of GPUs here:
        tune_transformer(num_samples=8, gpus_per_trial=1)
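
To verify the setup quickly, run the script with the `--smoke-test` flag (the file name below is just a placeholder):

python pbt_transformer_example.py --smoke-test

Without the flag, the script runs the full search: 8 trials with 1 GPU each.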

Utility code, imported by the script above.

"""用于加载和缓存数据的工具。"""

import os
from typing import Callable, Dict
import numpy as np
from transformers import EvalPrediction
from transformers import glue_compute_metrics, glue_output_modes


def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
    """来自transformers/examples/text-classification/run_glue.py的功能"""
    output_mode = glue_output_modes[task_name]

    def compute_metrics_fn(p: EvalPrediction):
        if output_mode == "classification":
            preds = np.argmax(p.predictions, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(p.predictions)
        metrics = glue_compute_metrics(task_name, preds, p.label_ids)
        return metrics

    return compute_metrics_fn


def download_data(task_name, data_dir="./data"):
    # Download RTE training data
    print("Downloading dataset.")
    import urllib.request
    import zipfile

    if task_name == "rte":
        url = "https://dl.fbaipublicfiles.com/glue/data/RTE.zip"
    else:
        raise ValueError("Unknown task: {}".format(task_name))
    data_file = os.path.join(data_dir, "{}.zip".format(task_name))
    if not os.path.exists(data_file):
        urllib.request.urlretrieve(url, data_file)
        with zipfile.ZipFile(data_file) as zip_ref:
            zip_ref.extractall(data_dir)
        print("Downloaded data for task {} to {}".format(task_name, data_dir))
    else:
        print(
            "Data already exists. Using downloaded data for task {} from {}".format(
                task_name, data_dir
            )
        )
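
For reference, here is a minimal sketch of exercising the two helpers above on their own. It assumes both functions are in scope and a transformers version that still ships the legacy GLUE helpers (`EvalPrediction`, `glue_compute_metrics`):

import os
import numpy as np
from transformers import EvalPrediction

# download_data writes RTE.zip into data_dir, so the directory must exist.
os.makedirs("./data", exist_ok=True)
download_data("rte", data_dir="./data")  # no-op if RTE.zip is already there

# Build the metrics callback and apply it to dummy logits.
compute_metrics = build_compute_metrics_fn("rte")
dummy = EvalPrediction(
    predictions=np.array([[0.9, 0.1], [0.2, 0.8]]),
    label_ids=np.array([0, 1]),
)
print(compute_metrics(dummy))  # expected along the lines of {'acc': 1.0}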