使用 KubeRay 和 GCSFuse 进行分布式检查点#

此示例使用 KubeRay 进行分布式检查点操作,使用 GCSFuse CSI 驱动程序和 Google Cloud Storage 作为远程存储系统。为了说明这些概念,本指南使用了 使用 Ray Train 微调 Pytorch 图像分类器 示例。

为什么使用 GCSFuse 进行分布式检查点?#

在大规模、高性能的机器学习中,分布式检查点对于容错至关重要,确保如果在训练过程中节点发生故障,Ray可以从最新的保存检查点恢复进程,而不是从头开始。虽然可以直接引用远程存储路径(例如,gs://my-checkpoint-bucket),但使用Google Cloud Storage FUSE(GCSFuse)对于分布式应用有明显的优势。GCSFuse允许你将Cloud Storage桶挂载为本地文件系统,使得依赖这些语义的分布式应用的检查点管理更加直观。此外,GCSFuse是为高性能工作负载设计的,能够提供分布式检查点所需的高性能和可扩展性,适用于大型模型的检查点。

分布式检查点,结合 GCSFuse,可以实现更大规模的模型训练,提高可用性和效率。

在 GKE 上创建一个 Kubernetes 集群#

创建一个启用了 GCSFuse CSI 驱动工作负载身份 的 GKE 集群,以及一个包含 4 个 L4 GPU 的 GPU 节点池:

export PROJECT_ID=<your project id>
gcloud container clusters create kuberay-with-gcsfuse \
    --addons GcsFuseCsiDriver \
    --cluster-version=1.29.4 \
    --location=us-east4-c \
    --machine-type=g2-standard-8 \
    --release-channel=rapid \
    --num-nodes=4 \
    --accelerator type=nvidia-l4,count=1,gpu-driver-version=latest \
    --workload-pool=${PROJECT_ID}.svc.id.goog

验证您的集群成功创建,包含4个GPU:

$ kubectl get nodes "-o=custom-columns=NAME:.metadata.name,GPU:.status.allocatable.nvidia\.com/gpu"
NAME                                                  GPU
gke-kuberay-with-gcsfuse-default-pool-xxxx-0000       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-1111       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-2222       1
gke-kuberay-with-gcsfuse-default-pool-xxxx-3333       1

安装 KubeRay 操作员#

按照 部署 KubeRay 操作员 来从 Helm 仓库安装最新稳定的 KubeRay 操作员。如果你正确地为 GPU 节点池设置了污点,那么 KubeRay 操作员 Pod 必须在 CPU 节点上。

配置GCS存储桶#

创建一个 Ray 用作远程文件系统的 GCS 存储桶。

BUCKET=<your GCS bucket>
gcloud storage buckets create gs://$BUCKET --uniform-bucket-level-access

创建一个 Kubernetes ServiceAccount,授予 RayCluster 挂载 GCS 存储桶的权限:

kubectl create serviceaccount pytorch-distributed-training

roles/storage.objectUser 角色绑定到 Kubernetes 服务账户和存储桶 IAM 策略。请参阅 识别项目 以找到您的项目 ID 和项目编号:

PROJECT_ID=<your project ID>
PROJECT_NUMBER=<your project number>
gcloud storage buckets add-iam-policy-binding gs://${BUCKET} --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/default/sa/pytorch-distributed-training"  --role "roles/storage.objectUser"

更多详情请参阅 使用 Cloud Storage FUSE CSI 驱动访问云存储桶

部署 RayJob#

下载执行 使用 Ray Train 微调 Pytorch 图像分类器 中记录的所有步骤的 RayJob。源代码 也在 KubeRay 仓库中。

curl -LO https://raw.githubusercontent.com/ray-project/kuberay/master/ray-operator/config/samples/pytorch-resnet-image-classifier/ray-job.pytorch-image-classifier.yaml

通过将 GCS_BUCKET 占位符替换为您之前创建的 Google Cloud Storage 存储桶来修改 RayJob。或者,您可以使用 sed

sed -i "s/GCS_BUCKET/$BUCKET/g" ray-job.pytorch-image-classifier.yaml

部署 RayJob:

kubectl create -f ray-job.pytorch-image-classifier.yaml

部署的 RayJob 包含以下配置,以启用分布式检查点到共享文件系统:

  • 4 个 Ray 工作节点,每个节点配备一个 GPU。

  • 所有 Ray 节点使用我们之前创建的 pytorch-distributed-training 服务账户。

  • 包含由 gcsfuse.csi.storage.gke.io CSI 驱动程序管理的卷。

  • 挂载一个共享存储路径 /mnt/cluster_storage,由您之前创建的GCS存储桶支持。

您可以通过注释配置 Pod,这允许对 GCSFuse 边车容器进行更精细的控制。更多详情请参见 指定 Pod 注释

annotations:
  gke-gcsfuse/volumes: "true"
  gke-gcsfuse/cpu-limit: "0"
  gke-gcsfuse/memory-limit: 5Gi
  gke-gcsfuse/ephemeral-storage-limit: 10Gi

在定义 GCSFuse 容器卷时,您还可以指定挂载选项:

csi:
  driver: gcsfuse.csi.storage.gke.io
  volumeAttributes:
    bucketName: GCS_BUCKET
    mountOptions: "implicit-dirs,uid=1000,gid=100"

参见 挂载选项 以了解更多关于挂载选项的信息。

Ray 作业的日志应指示在 /mnt/cluster_storage 中使用共享远程文件系统和检查点目录。例如:

Training finished iteration 10 at 2024-04-29 10:22:08. Total running time: 1min 30s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000009 │
│ time_this_iter_s                6.47154 │
│ time_total_s                    74.5547 │
│ training_iteration                   10 │
│ acc                             0.24183 │
│ loss                            0.06882 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_cbb82_00000_0_2024-04-29_10-20-37/checkpoint_000009

检查检查点数据#

一旦 RayJob 完成,你可以使用像 gsutil 这样的工具来检查你的存储桶内容。

gsutil ls gs://my-ray-bucket/**
gs://my-ray-bucket/finetune-resnet/
gs://my-ray-bucket/finetune-resnet/.validate_storage_marker
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.txt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436502.orch-image-classifier-nc2sq-raycluster-tdrfx-head-xzcl8
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436809.orch-image-classifier-zz4sj-raycluster-vn7kz-head-lwx8k
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.json
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/progress.csv
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/result.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/trainer.pkl
gs://my-ray-bucket/finetune-resnet/tuner.pkl

从检查点恢复#

在任务失败的情况下,您可以使用最新的检查点来恢复模型的训练。此示例配置 TorchTrainer 以自动从最新的检查点恢复:

experiment_path = os.path.expanduser("/mnt/cluster_storage/finetune-resnet")
if TorchTrainer.can_restore(experiment_path):
    trainer = TorchTrainer.restore(experiment_path,
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )
else:
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )

您可以通过重新部署相同的 RayJob 来验证自动检查点恢复:

kubectl create -f ray-job.pytorch-image-classifier.yaml

如果之前的任务成功,训练任务应该从 checkpoint_000009 目录恢复检查点状态,然后立即以0次迭代完成训练:

2024-04-29 15:51:32,528 INFO experiment_state.py:366 -- Trying to find and download experiment checkpoint at /mnt/cluster_storage/finetune-resnet
2024-04-29 15:51:32,651 INFO experiment_state.py:396 -- A remote experiment checkpoint was found and will be used to restore the previous experiment state.
2024-04-29 15:51:32,652 INFO tune_controller.py:404 -- Using the newest experiment state file found within the experiment directory: experiment_state-2024-04-29_15-43-40.json

View detailed results here: /mnt/cluster_storage/finetune-resnet
To visualize your results with TensorBoard, run: `tensorboard --logdir /home/ray/ray_results/finetune-resnet`

Result(
  metrics={'loss': 0.070047477101968, 'acc': 0.23529411764705882},
  path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40/checkpoint_000009)
)

如果之前的作业在较早的检查点失败,作业应从最后一个保存的检查点恢复并运行,直到 max_epochs=10。例如,如果上次运行在第7个epoch失败,训练将自动使用 checkpoint_000006 恢复,并再运行3个迭代,直到第10个epoch:

(TorchTrainer pid=611, ip=10.108.2.65) Restored on 10.108.2.65 from checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000006)
(RayTrainWorker pid=671, ip=10.108.2.65) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=611, ip=10.108.2.65) Started distributed worker processes:
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.2.65, pid=671) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.1.83, pid=589) world_rank=1, local_rank=0, node_rank=1
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.0.72, pid=590) world_rank=2, local_rank=0, node_rank=2
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.3.76, pid=590) world_rank=3, local_rank=0, node_rank=3
(RayTrainWorker pid=589, ip=10.108.1.83) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
(RayTrainWorker pid=671, ip=10.108.2.65)
  0%|          | 0.00/97.8M [00:00<?, ?B/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
 22%|██▏       | 21.8M/97.8M [00:00<00:00, 229MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
 92%|█████████▏| 89.7M/97.8M [00:00<00:00, 327MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
100%|██████████| 97.8M/97.8M [00:00<00:00, 316MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Moving model to device: cuda:0
(RayTrainWorker pid=671, ip=10.108.2.65) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=671, ip=10.108.2.65) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=590, ip=10.108.3.76)
  0%|          | 0.00/97.8M [00:00<?, ?B/s] [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72)
 85%|████████▍ | 82.8M/97.8M [00:00<00:00, 256MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 231MB/s] [repeated 11x across cluster]
(RayTrainWorker pid=590, ip=10.108.3.76)
100%|██████████| 97.8M/97.8M [00:00<00:00, 238MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-train Loss: 0.0903 Acc: 0.2418
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-val Loss: 0.0881 Acc: 0.2353
(RayTrainWorker pid=590, ip=10.108.0.72) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007)
(RayTrainWorker pid=590, ip=10.108.0.72) Moving model to device: cuda:0 [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72) Wrapping provided model in DistributedDataParallel. [repeated 3x across cluster]

Training finished iteration 8 at 2024-04-29 17:27:29. Total running time: 54s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000007 │
│ time_this_iter_s               40.46113 │
│ time_total_s                   95.00043 │
│ training_iteration                    8 │
│ acc                             0.23529 │
│ loss                            0.08811 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 8 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-train Loss: 0.0893 Acc: 0.2459
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-val Loss: 0.0859 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008) [repeated 4x across cluster]

Training finished iteration 9 at 2024-04-29 17:27:36. Total running time: 1min 1s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000008 │
│ time_this_iter_s                5.99923 │
│ time_total_s                  100.99965 │
│ training_iteration                    9 │
│ acc                             0.23529 │
│ loss                            0.08592 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 9 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008
2024-04-29 17:27:37,170 WARNING util.py:202 -- The `process_trial_save` operation took 0.540 s, which may be a performance bottleneck.
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-train Loss: 0.0866 Acc: 0.2377
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-val Loss: 0.0833 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 4x across cluster]

Training finished iteration 10 at 2024-04-29 17:27:43. Total running time: 1min 8s
╭─────────────────────────────────────────╮
│ Training result                         │
├─────────────────────────────────────────┤
│ checkpoint_dir_name   checkpoint_000009 │
│ time_this_iter_s                6.71457 │
│ time_total_s                  107.71422 │
│ training_iteration                   10 │
│ acc                             0.23529 │
│ loss                            0.08333 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009

Training completed after 10 iterations at 2024-04-29 17:27:45. Total running time: 1min 9s
2024-04-29 17:27:46,236 WARNING experiment_state.py:323 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.

Result(
  metrics={'loss': 0.08333033206416111, 'acc': 0.23529411764705882},
  path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29',
  filesystem='local',
  checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009)
)
(RayTrainWorker pid=590, ip=10.108.3.76) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 3x across cluster]