Train a BERT model with Intel Gaudi#

In this notebook, we train a BERT model for sequence classification on the full Yelp Reviews dataset. We use the transformers and datasets libraries from Hugging Face, together with ray.train for distributed training.

Intel Gaudi AI Processors (HPUs) are AI hardware accelerators designed by Intel Habana Labs. For more information, see Gaudi Architecture and Gaudi Developer Documentation.

Configuration#

Running this example requires a node with Gaudi/Gaudi2 installed. Both Gaudi and Gaudi2 have 8 HPUs. We use 2 workers to train the model, each using 1 HPU.

We recommend using a prebuilt container to run these examples. To run a container, you need Docker. See Install Docker Engine for installation instructions.

Next, follow Run Using Containers to install the Gaudi drivers and the container runtime.

Then start the Gaudi container:

docker pull vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest

Inside the container, install the following dependencies to run this notebook.

pip install ray[train] notebook transformers datasets evaluate
# Import necessary libraries

import os
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import numpy as np
import evaluate
from datasets import load_dataset
import transformers
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchConfig
from ray.runtime_env import RuntimeEnv

import habana_frameworks.torch.core as htcore
/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(

Metrics setup#

We use accuracy as our evaluation metric. The compute_metrics function computes the accuracy of the model's predictions.

# Metric
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
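
As a quick sanity check (illustrative only, not part of the training pipeline), compute_metrics can be called directly on dummy logits and labels. Two of the three argmax predictions below match the references, so the expected accuracy is about 0.667.

# Dummy example: argmax of the logits gives predictions [1, 0, 1] against labels [1, 0, 0].
dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
dummy_labels = np.array([1, 0, 0])
print(compute_metrics((dummy_logits, dummy_labels)))  # {'accuracy': 0.666...}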

Training function#

Each worker executes this function during training. It handles data loading, tokenization, model initialization, and the training loop. Compared with a training function written for GPUs, no changes are needed to migrate to HPUs. Internally, Ray Train does the following:

  • Detects HPUs and sets the device.

  • Initializes the Habana PyTorch backend.

  • Initializes the Habana distributed backend.
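
As a minimal illustration of the first point (a sketch only, not part of the training script below): the same device lookup that returns a CUDA device on a GPU worker returns an HPU device when the worker is scheduled with an HPU resource, so the rest of the training code stays device-agnostic.

# Hedged sketch: assumes it runs inside a Ray Train worker scheduled with {"HPU": 1}.
import torch
import ray.train.torch

def show_worker_device():
    device = ray.train.torch.get_device()  # e.g. an "hpu" device on a Gaudi worker
    x = torch.ones(4, 4).to(device)        # tensors move to the HPU just like to a GPU
    print(f"worker device: {device}, tensor device: {x.device}")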

def train_func_per_worker(config: Dict):
    
    # Datasets
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    train_dataset = dataset["train"].select(range(1000)).map(tokenize_function, batched=True)
    eval_dataset = dataset["test"].select(range(1000)).map(tokenize_function, batched=True)

    # Prepare dataloaders for each worker
    dataloaders = {}
    dataloaders["train"] = torch.utils.data.DataLoader(
        train_dataset, 
        shuffle=True, 
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )
    dataloaders["test"] = torch.utils.data.DataLoader(
        eval_dataset, 
        shuffle=True, 
        collate_fn=transformers.default_data_collator, 
        batch_size=batch_size
    )

    # Automatically get the HPU device
    device = ray.train.torch.get_device()

    # Prepare the model and optimizer
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    model = model.to(device)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Start the training loop
    for epoch in range(epochs):
        # Each epoch has a training and an evaluation phase
        for phase in ["train", "test"]:
            if phase == "train":
                model.train()  # Set the model to training mode
            else:
                model.eval()  # Set the model to evaluation mode

            for batch in dataloaders[phase]:
                batch = {k: v.to(device) for k, v in batch.items()}

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and compute the loss
                    
                    outputs = model(**batch)
                    loss = outputs.loss

                    # Backward pass and optimization only in the training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()
                        print(f"train epoch:[{epoch}]\tloss:{loss:.6f}")

Main training function#

The train_bert function sets up the distributed training environment with Ray and launches the training process. To train with HPUs, the only required changes are:

  • Require an HPU for each worker in ScalingConfig

  • Set the backend to "hccl" in TorchConfig

def train_bert(num_workers=2):
    global_batch_size = 8

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    # In ScalingConfig, require an HPU for each worker.
    scaling_config = ScalingConfig(num_workers=num_workers, resources_per_worker={"CPU": 1, "HPU": 1})
    # Set the backend to hccl in TorchConfig.
    torch_config = TorchConfig(backend="hccl")
    
    # Start your Ray cluster
    ray.init()
    
    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        torch_config=torch_config,
        scaling_config=scaling_config,
    )

    result = trainer.fit()
    print(f"Training result: {result}")

Start training#

Finally, we call the train_bert function to start training. You can adjust the number of workers to use.

Note: the following warning is expected and is resolved in SynapseAI version 1.14.0+:

/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
train_bert(num_workers=2)

Tune Status

Current time: 2024-02-28 07:05:06
Running for: 00:05:09.32
Memory: 389.1/1007.5 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 3.0/160 CPUs, 0/0 GPUs (0.0/1.0 TPU, 2.0/8.0 HPU)

Trial Status

Trial name                status      loc
TorchTrainer_fb74f_00000  TERMINATED  172.17.0.3:59382
(pid=59382) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
(pid=59382)   warnings.warn(
(RayTrainWorker pid=66009) Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=66010) /usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py:252: UserWarning: Device capability of hccl unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`. [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=66010)   warnings.warn( [repeated 2x across cluster]
(TorchTrainer pid=59382) Started distributed worker processes: 
(TorchTrainer pid=59382) - (ip=172.17.0.3, pid=66009) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=59382) - (ip=172.17.0.3, pid=66010) world_rank=1, local_rank=1, node_rank=0
Downloading readme: 100%|██████████| 6.72k/6.72k [00:00<00:00, 21.0MB/s]
Downloading data: 100%|██████████| 299M/299M [00:06<00:00, 43.0MB/s]
Downloading data: 100%|██████████| 23.5M/23.5M [00:00<00:00, 38.5MB/s]
Generating train split: 100%|██████████| 650000/650000 [00:02<00:00, 237989.20 examples/s]
Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 247162.55 examples/s]
Map:   0%|          | 0/1000 [00:00<?, ? examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 2898.10 examples/s]
(RayTrainWorker pid=66009) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=66009) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(RayTrainWorker pid=66009) ============================= HABANA PT BRIDGE CONFIGURATION =========================== 
(RayTrainWorker pid=66009)  PT_HPU_LAZY_MODE = 1
(RayTrainWorker pid=66009)  PT_RECIPE_CACHE_PATH = 
(RayTrainWorker pid=66009)  PT_CACHE_FOLDER_DELETE = 0
(RayTrainWorker pid=66009)  PT_HPU_RECIPE_CACHE_CONFIG = 
(RayTrainWorker pid=66009)  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
(RayTrainWorker pid=66009)  PT_HPU_LAZY_ACC_PAR_MODE = 1
(RayTrainWorker pid=66009)  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
(RayTrainWorker pid=66009) ---------------------------: System Configuration :---------------------------
(RayTrainWorker pid=66009) Num CPU Cores : 160
(RayTrainWorker pid=66009) CPU RAM       : 1056389756 KB
(RayTrainWorker pid=66009) ------------------------------------------------------------------------------
Map:   0%|          | 0/1000 [00:00<?, ? examples/s] [repeated 3x across cluster]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 3179.11 examples/s] [repeated 3x across cluster]
(RayTrainWorker pid=66010) Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
(RayTrainWorker pid=66010) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(RayTrainWorker pid=66010) train epoch:[0]	loss:1.782888
(RayTrainWorker pid=66010) train epoch:[0]	loss:2.250521 [repeated 2x across cluster]
(RayTrainWorker pid=66010) train epoch:[0]	loss:2.005397 [repeated 114x across cluster]
(RayTrainWorker pid=66010) train epoch:[0]	loss:1.583421 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[0]	loss:1.873015 [repeated 117x across cluster]
(RayTrainWorker pid=66010) train epoch:[0]	loss:1.287454 [repeated 111x across cluster]
(RayTrainWorker pid=66010) train epoch:[1]	loss:1.256705 [repeated 35x across cluster]
(RayTrainWorker pid=66010) train epoch:[1]	loss:1.783350 [repeated 112x across cluster]
(RayTrainWorker pid=66009) train epoch:[1]	loss:1.161693 [repeated 117x across cluster]
(RayTrainWorker pid=66010) train epoch:[1]	loss:1.083962 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[1]	loss:1.452244 [repeated 126x across cluster]
(RayTrainWorker pid=66010) train epoch:[2]	loss:0.848569 [repeated 23x across cluster]
(RayTrainWorker pid=66010) train epoch:[2]	loss:0.935847 [repeated 104x across cluster]
(RayTrainWorker pid=66010) train epoch:[2]	loss:2.003910 [repeated 133x across cluster]
(RayTrainWorker pid=66010) train epoch:[2]	loss:0.719678 [repeated 119x across cluster]
(RayTrainWorker pid=66009) train epoch:[2]	loss:1.115227 [repeated 128x across cluster]
(RayTrainWorker pid=66010) train epoch:[3]	loss:1.476088 [repeated 16x across cluster]
(RayTrainWorker pid=66010) train epoch:[3]	loss:0.938356 [repeated 95x across cluster]
(RayTrainWorker pid=66010) train epoch:[3]	loss:0.880045 [repeated 124x across cluster]
(RayTrainWorker pid=66010) train epoch:[3]	loss:0.906078 [repeated 126x across cluster]
(RayTrainWorker pid=66010) 
(RayTrainWorker pid=66010) train epoch:[3]	loss:0.977447 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[4]	loss:0.545720 [repeated 34x across cluster]
(RayTrainWorker pid=66010) train epoch:[4]	loss:0.733710 [repeated 114x across cluster]
(RayTrainWorker pid=66010) train epoch:[4]	loss:0.894966 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[4]	loss:1.428036 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[4]	loss:1.482066 [repeated 122x across cluster]
(RayTrainWorker pid=66010) train epoch:[5]	loss:1.564706 [repeated 22x across cluster]
(RayTrainWorker pid=66010) train epoch:[5]	loss:1.853072 [repeated 121x across cluster]
(RayTrainWorker pid=66010) train epoch:[5]	loss:2.260058 [repeated 129x across cluster]
(RayTrainWorker pid=66010) train epoch:[5]	loss:1.414144 [repeated 128x across cluster]
(RayTrainWorker pid=66009) train epoch:[5]	loss:0.980207 [repeated 118x across cluster]
(RayTrainWorker pid=66010) train epoch:[6]	loss:1.559380 [repeated 7x across cluster]
(RayTrainWorker pid=66010) train epoch:[6]	loss:1.634878 [repeated 123x across cluster]
(RayTrainWorker pid=66010) train epoch:[6]	loss:1.564483 [repeated 132x across cluster]
(RayTrainWorker pid=66010) train epoch:[6]	loss:1.733673 [repeated 136x across cluster]
(RayTrainWorker pid=66010) train epoch:[7]	loss:1.582968 [repeated 105x across cluster]
(RayTrainWorker pid=66010) train epoch:[7]	loss:1.486512 [repeated 133x across cluster]
(RayTrainWorker pid=66010) train epoch:[7]	loss:1.723742 [repeated 134x across cluster]
(RayTrainWorker pid=66010) train epoch:[7]	loss:1.556943 [repeated 137x across cluster]
(RayTrainWorker pid=66010) train epoch:[8]	loss:1.613637 [repeated 96x across cluster]
(RayTrainWorker pid=66010) train epoch:[8]	loss:1.744777 [repeated 132x across cluster]
(RayTrainWorker pid=66010) train epoch:[8]	loss:1.816669 [repeated 131x across cluster]
(RayTrainWorker pid=66010) train epoch:[8]	loss:1.313460 [repeated 128x across cluster]
(RayTrainWorker pid=66009) train epoch:[9]	loss:1.920412 [repeated 109x across cluster]
(RayTrainWorker pid=66010) train epoch:[9]	loss:1.687392 [repeated 131x across cluster]
(RayTrainWorker pid=66010) train epoch:[9]	loss:1.714871 [repeated 126x across cluster]
(RayTrainWorker pid=66010) train epoch:[9]	loss:1.679613 [repeated 139x across cluster]
Trial TorchTrainer_fb74f_00000 completed. Last result: 
2024-02-28 07:05:06,559	INFO tune.py:1042 -- Total run time: 309.37 seconds (309.32 seconds for the tuning loop).
Training result: Result(
  metrics={},
  path='/root/ray_results/TorchTrainer_2024-02-28_06-59-57/TorchTrainer_fb74f_00000_0_2024-02-28_06-59-57',
  filesystem='local',
  checkpoint=None
)
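
Because the training function neither reports metrics nor saves checkpoints, metrics is empty and checkpoint is None in the Result above. A small hedged sketch of how such a Result object could be inspected programmatically (summarize_result is a hypothetical helper, not part of the tutorial):

def summarize_result(result: ray.train.Result) -> None:
    # Fields correspond to what is printed above.
    print("run directory:", result.path)            # local directory with logs for this run
    print("reported metrics:", result.metrics)      # {} unless the loop calls ray.train.report
    print("latest checkpoint:", result.checkpoint)  # None unless checkpoints are saved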