Distributed checkpointing with KubeRay and GCSFuse#
This example performs distributed checkpointing with KubeRay, using the GCSFuse CSI driver and Google Cloud Storage as the remote storage system. To illustrate the concepts, this guide uses the Finetuning a Pytorch Image Classifier with Ray Train example.
Why use GCSFuse for distributed checkpointing?#
In large-scale, high-performance machine learning, distributed checkpointing is critical for fault tolerance: if a node fails during training, Ray can resume from the most recently saved checkpoint instead of starting over. While you can reference a remote storage path directly (for example, gs://my-checkpoint-bucket), using Google Cloud Storage FUSE (GCSFuse) has distinct advantages for distributed applications. GCSFuse lets you mount a Cloud Storage bucket as a local file system, which makes checkpoint management more intuitive for distributed applications that rely on file system semantics. In addition, GCSFuse is designed for high-performance workloads, providing the performance and scalability that distributed checkpointing of large models requires.
Create a Kubernetes cluster on GKE#
Create a GKE cluster with the GCSFuse CSI driver and Workload Identity enabled, along with a GPU node pool containing 4 L4 GPUs:
export PROJECT_ID=<your project id>
gcloud container clusters create kuberay-with-gcsfuse \
--addons GcsFuseCsiDriver \
--cluster-version=1.29.4 \
--location=us-east4-c \
--machine-type=g2-standard-8 \
--release-channel=rapid \
--num-nodes=4 \
--accelerator type=nvidia-l4,count=1,gpu-driver-version=latest \
--workload-pool=${PROJECT_ID}.svc.id.goog
Verify that your cluster was created successfully with 4 GPUs:
$ kubectl get nodes "-o=custom-columns=NAME:.metadata.name,GPU:.status.allocatable.nvidia\.com/gpu"
NAME GPU
gke-kuberay-with-gcsfuse-default-pool-xxxx-0000 1
gke-kuberay-with-gcsfuse-default-pool-xxxx-1111 1
gke-kuberay-with-gcsfuse-default-pool-xxxx-2222 1
gke-kuberay-with-gcsfuse-default-pool-xxxx-3333 1
Install the KubeRay operator#
Follow Deploy a KubeRay operator to install the latest stable KubeRay operator from the Helm repository. If you set up the taint for the GPU node pool correctly, the KubeRay operator Pod schedules onto a CPU node.
Configure a GCS bucket#
Create a GCS bucket for Ray to use as the remote file system.
BUCKET=<your GCS bucket>
gcloud storage buckets create gs://$BUCKET --uniform-bucket-level-access
Create a Kubernetes ServiceAccount that grants the RayCluster permission to mount the GCS bucket:
kubectl create serviceaccount pytorch-distributed-training
Bind the roles/storage.objectUser role to the Kubernetes ServiceAccount in the bucket IAM policy. See Identifying projects to find your project ID and project number:
PROJECT_ID=<your project ID>
PROJECT_NUMBER=<your project number>
gcloud storage buckets add-iam-policy-binding gs://${BUCKET} --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/default/sa/pytorch-distributed-training" --role "roles/storage.objectUser"
For more details, see Access Cloud Storage buckets with the Cloud Storage FUSE CSI driver.
Deploy a RayJob#
Download the RayJob that executes all the steps documented in Finetuning a Pytorch Image Classifier with Ray Train. The source code is also in the KubeRay repository.
curl -LO https://raw.githubusercontent.com/ray-project/kuberay/master/ray-operator/config/samples/pytorch-resnet-image-classifier/ray-job.pytorch-image-classifier.yaml
Modify the RayJob by replacing the GCS_BUCKET placeholder with the Google Cloud Storage bucket you created earlier. Alternatively, you can use sed:
sed -i "s/GCS_BUCKET/$BUCKET/g" ray-job.pytorch-image-classifier.yaml
Deploy the RayJob:
kubectl create -f ray-job.pytorch-image-classifier.yaml
The deployed RayJob includes the following configuration to enable distributed checkpointing to the shared file system:
- 4 Ray worker nodes, each with 1 GPU.
- All Ray nodes use the pytorch-distributed-training ServiceAccount created earlier.
- A volume managed by the gcsfuse.csi.storage.gke.io CSI driver.
- A shared storage path, /mnt/cluster_storage, backed by the GCS bucket you created earlier.
You can configure the Pods with annotations, which allow finer-grained control over the GCSFuse sidecar container. For more details, see Specify Pod annotations.
annotations:
gke-gcsfuse/volumes: "true"
gke-gcsfuse/cpu-limit: "0"
gke-gcsfuse/memory-limit: 5Gi
gke-gcsfuse/ephemeral-storage-limit: 10Gi
You can also specify mount options when defining the GCSFuse container volume:
csi:
driver: gcsfuse.csi.storage.gke.io
volumeAttributes:
bucketName: GCS_BUCKET
mountOptions: "implicit-dirs,uid=1000,gid=100"
See Mount options to learn more about mount options.
The logs of the Ray job should indicate the use of the shared remote file system and a checkpoint directory under /mnt/cluster_storage. For example:
Training finished iteration 10 at 2024-04-29 10:22:08. Total running time: 1min 30s
╭─────────────────────────────────────────╮
│ Training result │
├─────────────────────────────────────────┤
│ checkpoint_dir_name checkpoint_000009 │
│ time_this_iter_s 6.47154 │
│ time_total_s 74.5547 │
│ training_iteration 10 │
│ acc 0.24183 │
│ loss 0.06882 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_cbb82_00000_0_2024-04-29_10-20-37/checkpoint_000009
Inspect checkpoint data#
Once the RayJob completes, you can inspect the contents of your bucket with a tool such as gsutil.
gsutil ls gs://my-ray-bucket/**
gs://my-ray-bucket/finetune-resnet/
gs://my-ray-bucket/finetune-resnet/.validate_storage_marker
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/checkpoint.pt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/error.txt
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436502.orch-image-classifier-nc2sq-raycluster-tdrfx-head-xzcl8
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/events.out.tfevents.1714436809.orch-image-classifier-zz4sj-raycluster-vn7kz-head-lwx8k
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.json
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/params.pkl
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/progress.csv
gs://my-ray-bucket/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/result.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/basic-variant-state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-21-29.json
gs://my-ray-bucket/finetune-resnet/experiment_state-2024-04-29_17-26-35.json
gs://my-ray-bucket/finetune-resnet/trainer.pkl
gs://my-ray-bucket/finetune-resnet/tuner.pkl
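You can also inspect a checkpoint programmatically from any Pod that has the bucket mounted. The following is a minimal sketch that only assumes checkpoint.pt is an ordinary PyTorch file saved by the example; the exact contents depend on the training script, and the path shown comes from the listing above, so substitute the directory from your own run.
import torch

# Path copied from the bucket listing above; replace it with your own
# experiment directory and checkpoint name.
ckpt_path = (
    "/mnt/cluster_storage/finetune-resnet/"
    "TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009/checkpoint.pt"
)

# Load on CPU so the inspection works on a node without a GPU.
state = torch.load(ckpt_path, map_location="cpu")
if isinstance(state, dict):
    print(list(state.keys()))  # show what the training script saved
else:
    print(type(state))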
Restore from a checkpoint#
In the event of a job failure, you can resume training the model from the latest checkpoint. This example configures the TorchTrainer to automatically restore from the most recent checkpoint:
import os

from ray.train.torch import TorchTrainer

# Restore the experiment if a previous run saved state under the shared
# storage path; otherwise, start a fresh trainer.
experiment_path = os.path.expanduser("/mnt/cluster_storage/finetune-resnet")
if TorchTrainer.can_restore(experiment_path):
    trainer = TorchTrainer.restore(
        experiment_path,
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )
else:
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
    )
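In the full sample, the trainer then runs the same way regardless of whether it was restored. A short usage sketch, using standard Ray Train Result fields:
result = trainer.fit()
print(result.metrics)     # final reported metrics, such as loss and acc
print(result.checkpoint)  # latest checkpoint stored under /mnt/cluster_storage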
You can verify automatic checkpoint restoration by redeploying the same RayJob:
kubectl create -f ray-job.pytorch-image-classifier.yaml
If the previous job succeeded, the training job should restore the checkpoint state from the checkpoint_000009 directory and then immediately complete training with 0 additional iterations:
2024-04-29 15:51:32,528 INFO experiment_state.py:366 -- Trying to find and download experiment checkpoint at /mnt/cluster_storage/finetune-resnet
2024-04-29 15:51:32,651 INFO experiment_state.py:396 -- A remote experiment checkpoint was found and will be used to restore the previous experiment state.
2024-04-29 15:51:32,652 INFO tune_controller.py:404 -- Using the newest experiment state file found within the experiment directory: experiment_state-2024-04-29_15-43-40.json
View detailed results here: /mnt/cluster_storage/finetune-resnet
To visualize your results with TensorBoard, run: `tensorboard --logdir /home/ray/ray_results/finetune-resnet`
Result(
metrics={'loss': 0.070047477101968, 'acc': 0.23529411764705882},
path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40',
filesystem='local',
checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_ecc04_00000_0_2024-04-29_15-43-40/checkpoint_000009)
)
If the previous job failed at an earlier checkpoint, the job should resume from the last saved checkpoint and run until max_epochs=10. For example, if the last run failed at epoch 7, training automatically restores from checkpoint_000006 and runs 3 more iterations until the 10th epoch:
(TorchTrainer pid=611, ip=10.108.2.65) Restored on 10.108.2.65 from checkpoint: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000006)
(RayTrainWorker pid=671, ip=10.108.2.65) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=611, ip=10.108.2.65) Started distributed worker processes:
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.2.65, pid=671) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.1.83, pid=589) world_rank=1, local_rank=0, node_rank=1
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.0.72, pid=590) world_rank=2, local_rank=0, node_rank=2
(TorchTrainer pid=611, ip=10.108.2.65) - (ip=10.108.3.76, pid=590) world_rank=3, local_rank=0, node_rank=3
(RayTrainWorker pid=589, ip=10.108.1.83) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
(RayTrainWorker pid=671, ip=10.108.2.65)
0%| | 0.00/97.8M [00:00<?, ?B/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
22%|██▏ | 21.8M/97.8M [00:00<00:00, 229MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
92%|█████████▏| 89.7M/97.8M [00:00<00:00, 327MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65)
100%|██████████| 97.8M/97.8M [00:00<00:00, 316MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Moving model to device: cuda:0
(RayTrainWorker pid=671, ip=10.108.2.65) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=671, ip=10.108.2.65) Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/ray/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=590, ip=10.108.3.76)
0%| | 0.00/97.8M [00:00<?, ?B/s] [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72)
85%|████████▍ | 82.8M/97.8M [00:00<00:00, 256MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 231MB/s] [repeated 11x across cluster]
(RayTrainWorker pid=590, ip=10.108.3.76)
100%|██████████| 97.8M/97.8M [00:00<00:00, 238MB/s]
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-train Loss: 0.0903 Acc: 0.2418
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 7-val Loss: 0.0881 Acc: 0.2353
(RayTrainWorker pid=590, ip=10.108.0.72) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007)
(RayTrainWorker pid=590, ip=10.108.0.72) Moving model to device: cuda:0 [repeated 3x across cluster]
(RayTrainWorker pid=590, ip=10.108.0.72) Wrapping provided model in DistributedDataParallel. [repeated 3x across cluster]
Training finished iteration 8 at 2024-04-29 17:27:29. Total running time: 54s
╭─────────────────────────────────────────╮
│ Training result │
├─────────────────────────────────────────┤
│ checkpoint_dir_name checkpoint_000007 │
│ time_this_iter_s 40.46113 │
│ time_total_s 95.00043 │
│ training_iteration 8 │
│ acc 0.23529 │
│ loss 0.08811 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 8 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000007
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-train Loss: 0.0893 Acc: 0.2459
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 8-val Loss: 0.0859 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008) [repeated 4x across cluster]
Training finished iteration 9 at 2024-04-29 17:27:36. Total running time: 1min 1s
╭─────────────────────────────────────────╮
│ Training result │
├─────────────────────────────────────────┤
│ checkpoint_dir_name checkpoint_000008 │
│ time_this_iter_s 5.99923 │
│ time_total_s 100.99965 │
│ training_iteration 9 │
│ acc 0.23529 │
│ loss 0.08592 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 9 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000008
2024-04-29 17:27:37,170 WARNING util.py:202 -- The `process_trial_save` operation took 0.540 s, which may be a performance bottleneck.
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-train Loss: 0.0866 Acc: 0.2377
(RayTrainWorker pid=671, ip=10.108.2.65) Epoch 9-val Loss: 0.0833 Acc: 0.2353
(RayTrainWorker pid=589, ip=10.108.1.83) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 4x across cluster]
Training finished iteration 10 at 2024-04-29 17:27:43. Total running time: 1min 8s
╭─────────────────────────────────────────╮
│ Training result │
├─────────────────────────────────────────┤
│ checkpoint_dir_name checkpoint_000009 │
│ time_this_iter_s 6.71457 │
│ time_total_s 107.71422 │
│ training_iteration 10 │
│ acc 0.23529 │
│ loss 0.08333 │
╰─────────────────────────────────────────╯
Training saved a checkpoint for iteration 10 at: (local)/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009
Training completed after 10 iterations at 2024-04-29 17:27:45. Total running time: 1min 9s
2024-04-29 17:27:46,236 WARNING experiment_state.py:323 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.
Result(
metrics={'loss': 0.08333033206416111, 'acc': 0.23529411764705882},
path='/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29',
filesystem='local',
checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009)
)
(RayTrainWorker pid=590, ip=10.108.3.76) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/mnt/cluster_storage/finetune-resnet/TorchTrainer_96923_00000_0_2024-04-29_17-21-29/checkpoint_000009) [repeated 3x across cluster]