Get Started with Distributed Training using PyTorch Lightning#

This tutorial walks through the process of converting an existing PyTorch Lightning script to use Ray Train.

Learn how to:

  1. Configure the Lightning Trainer so that it runs distributed with Ray and on the correct CPU or GPU device.

  2. Configure your training function to report metrics and save checkpoints.

  3. Configure scaling and the CPU or GPU resource requirements for your training job.

  4. Launch a distributed training job with a TorchTrainer.

Quickstart#

For reference, the final code follows:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_func():
    # Your PyTorch Lightning training code here.
    ...

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()

  1. train_func is the Python code that executes on each distributed training worker.

  2. ScalingConfig defines the number of distributed training workers and whether to use GPUs.

  3. TorchTrainer launches the distributed training job.

Compare a PyTorch Lightning training script with and without Ray Train.

PyTorch Lightning:

import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader)

PyTorch Lightning + Ray Train:

import os
import tempfile

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import lightning.pytorch as pl

import ray.train.lightning
from ray.train.torch import TorchTrainer

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)


def train_func():
    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Configure PyTorch Lightning Trainer.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[ray.train.lightning.RayTrainReportCallback()],
        # [1a] Optionally, disable the default checkpointing behavior
        # in favor of the `RayTrainReportCallback` above.
        enable_checkpointing=False,
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)

# [2] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [3] Launch distributed training job.
trainer = TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [3a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result: ray.train.Result = trainer.fit()

# [4] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(
            checkpoint_dir,
            ray.train.lightning.RayTrainReportCallback.CHECKPOINT_NAME,
        ),
    )

Set up a training function#

First, update your training code to support distributed training. Begin by wrapping your code in a training function:

def train_func():
    # Your model training code here.
    ...

Each distributed training worker executes this function.

You can also specify the input argument of train_func as a dictionary via the Trainer's train_loop_config. For example:

def train_func(config):
    lr = config["lr"]
    num_epochs = config["num_epochs"]

config = {"lr": 1e-4, "num_epochs": 10}
trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=config, ...)

Warning

Avoid passing large data objects through train_loop_config to reduce the serialization and deserialization overhead. Instead, initialize large objects (for example, datasets and models) directly in train_func.

 def load_dataset():
     # Return a large in-memory dataset
     ...

 def load_model():
     # Return a large in-memory model instance
     ...

-config = {"data": load_dataset(), "model": load_model()}

 def train_func(config):
-    data = config["data"]
-    model = config["model"]

+    data = load_dataset()
+    model = load_model()
     ...

 trainer = ray.train.torch.TorchTrainer(train_func, ...)

Ray Train sets up your distributed process group on each worker. You only need to make a few changes to your Lightning Trainer definition.

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     model = MyLightningModule(...)
     datamodule = MyLightningDataModule(...)

     trainer = pl.Trainer(
-        devices=[0, 1, 2, 3],
-        strategy=DDPStrategy(),
-        plugins=[LightningEnvironment()],
+        devices="auto",
+        accelerator="auto",
+        strategy=ray.train.lightning.RayDDPStrategy(),
+        plugins=[ray.train.lightning.RayLightningEnvironment()]
     )
+    trainer = ray.train.lightning.prepare_trainer(trainer)

     trainer.fit(model, datamodule=datamodule)

The following sections discuss each change.

Configure the distributed strategy#

Ray Train offers several subclassed distributed strategies for Lightning. These strategies retain the same argument list as their base strategy classes. Internally, they configure the root device and the distributed sampler arguments.

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        strategy=DDPStrategy(),
+        strategy=ray.train.lightning.RayDDPStrategy(),
         ...
     )
     ...
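
Because these strategies keep their base class's argument list, base-class options can be passed straight through. A minimal sketch (the find_unused_parameters option is illustrative, not required):

import ray.train.lightning

# RayDDPStrategy accepts the same keyword arguments as Lightning's DDPStrategy,
# so pass-through options such as find_unused_parameters work unchanged.
strategy = ray.train.lightning.RayDDPStrategy(find_unused_parameters=True)

# FSDP and DeepSpeed counterparts follow the same pattern:
# ray.train.lightning.RayFSDPStrategy and ray.train.lightning.RayDeepSpeedStrategy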

Configure the Ray cluster environment plugin#

Ray Train also provides a RayLightningEnvironment class as a specification for the Ray cluster. This utility class configures the worker's local rank, global rank, node rank, and world size.

 import lightning.pytorch as pl
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        plugins=[LightningEnvironment()],
+        plugins=[ray.train.lightning.RayLightningEnvironment()],
         ...
     )
     ...
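
If you need these rank values inside the training function yourself (for example, to log only from one worker), you can read them from the Ray Train context. A minimal sketch:

import ray.train

def train_func():
    ctx = ray.train.get_context()
    # Ranks and world size as set up for each Ray Train worker.
    if ctx.get_world_rank() == 0:
        print(
            f"world_size={ctx.get_world_size()}, "
            f"local_rank={ctx.get_local_rank()}, "
            f"node_rank={ctx.get_node_rank()}"
        )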

Configure parallel devices#

In addition, Ray TorchTrainer has already configured the correct CUDA_VISIBLE_DEVICES for you. Always use all available GPUs by setting devices="auto" and accelerator="auto".

 import lightning.pytorch as pl

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        devices=[0,1,2,3],
+        devices="auto",
+        accelerator="auto",
         ...
     )
     ...

Report checkpoints and metrics#

To persist your checkpoints and monitor training progress, add the ray.train.lightning.RayTrainReportCallback utility callback to your Trainer.

 import lightning.pytorch as pl
 from ray.train.lightning import RayTrainReportCallback

 def train_func():
     ...
     trainer = pl.Trainer(
         ...
-        callbacks=[...],
+        callbacks=[..., RayTrainReportCallback()],
     )
     ...

Reporting metrics and checkpoints to Ray Train enables you to support fault-tolerant training and hyperparameter optimization. Note that the ray.train.lightning.RayTrainReportCallback class only provides a simple implementation and can be further customized.
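
As a rough sketch of what such a customization could look like (a simplified illustration, not the exact built-in implementation), a callback can save a checkpoint at the end of each epoch and hand it to Ray Train together with the current metrics via ray.train.report:

import os
import tempfile

import lightning.pytorch as pl
import ray.train

class MyReportCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the most recently logged metrics as plain Python floats.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}

        with tempfile.TemporaryDirectory() as tmpdir:
            # Save a checkpoint into a temporary directory.
            trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint.ckpt"))

            # Report the metrics and checkpoint to Ray Train.
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
            )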

Prepare your Lightning Trainer#

Finally, pass your Lightning Trainer into prepare_trainer() to validate your configurations.

 import lightning.pytorch as pl
 import ray.train.lightning

 def train_func():
     ...
     trainer = pl.Trainer(...)
+    trainer = ray.train.lightning.prepare_trainer(trainer)
     ...

Configure scale and GPUs#

Outside of your training function, create a ScalingConfig object to configure:

  1. num_workers - The number of distributed training worker processes.

  2. use_gpu - Whether each worker should use a GPU (or CPU).

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

For more details, see Configuring Scale and GPUs.
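
ScalingConfig also accepts finer-grained resource requests. For example, resources_per_worker can reserve additional CPUs per worker, as in this sketch (the exact values are illustrative):

from ray.train import ScalingConfig

# Illustrative: 4 GPU workers, each also reserving 2 CPUs (e.g. for dataloading).
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},
)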

Configure persistent storage#

Create a RunConfig object to specify the path where results (including checkpoints and artifacts) are saved.

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

Warning

Specifying a *shared storage location* (such as cloud storage or NFS) is *optional* for single-node clusters, but **required** for multi-node clusters. Using a local path for a multi-node cluster raises an error during checkpointing.

For more details, see Persistent storage guide.

Launch a training job#

Putting this all together, you can now launch a distributed training job with a TorchTrainer.

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

Access training results#

After training completes, Ray Train returns a Result object, which contains information about the training run, including the metrics and checkpoints reported during training.

result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.

For more usage examples, see Inspecting Training Results.

Next steps#

After you convert your PyTorch Lightning training script to use Ray Train:

  • See User Guides to learn more about how to perform specific tasks.

  • Browse the Examples for end-to-end examples of how to use Ray Train.

  • Consult the API Reference for more details on the classes and methods used in this tutorial.

Version compatibility#

Ray Train is tested with pytorch_lightning versions 1.6.5 and 2.1.2. For full compatibility, use pytorch_lightning>=1.6.5. Earlier versions aren't prohibited but may result in unexpected issues. If you run into any compatibility issues, consider upgrading your PyTorch Lightning version or filing an issue.

Note

If you are using Lightning 2.x, use the import path lightning.pytorch.xxx instead of pytorch_lightning.xxx.
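
Concretely, the difference is only in the top-level import (a minimal illustration):

# Lightning 2.x
import lightning.pytorch as pl

# Legacy pytorch_lightning 1.x
# import pytorch_lightning as pl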

LightningTrainer Migration Guide#

Ray 2.4 introduced the LightningTrainer and exposed a LightningConfigBuilder to define configurations for pl.LightningModule and pl.Trainer.

It then instantiated the model and trainer objects and ran a pre-defined training function in a black box.

This version of the LightningTrainer API limited your ability to manage the training function.

Ray 2.7 introduced the newly unified TorchTrainer API, which offers enhanced transparency, flexibility, and simplicity. This API aligns more closely with standard PyTorch Lightning scripts, ensuring that you have better control over your native Lightning code.

(Deprecated) LightningTrainer:

from ray.train.lightning import LightningConfigBuilder, LightningTrainer

config_builder = LightningConfigBuilder()
# [1] Collect model configs
config_builder.module(cls=MyLightningModule, lr=1e-3, feature_dim=128)

# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

# [3] Collect pl.Trainer configs
config_builder.trainer(
    max_epochs=10,
    accelerator="gpu",
    log_every_n_steps=100,
)

# [4] Build datasets on the head node
datamodule = MyLightningDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)

# [5] Execute the internal training function in a black box
ray_trainer = LightningTrainer(
    lightning_config=config_builder.build(),
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()

# [6] Load the trained model from an opaque Lightning-specific checkpoint.
lightning_checkpoint = result.checkpoint
model = lightning_checkpoint.get_model(MyLightningModule)

(New) TorchTrainer:

import os

import lightning.pytorch as pl

import ray.train
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer
)

def train_func():
    # [1] Create a Lightning model
    model = MyLightningModule(lr=1e-3, feature_dim=128)

    # [2] Report Checkpoint with callback
    ckpt_report_callback = RayTrainReportCallback()

    # [3] Create a Lightning Trainer
    trainer = pl.Trainer(
        max_epochs=10,
        log_every_n_steps=100,
        # New configurations below
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[ckpt_report_callback],
    )

    # Validate your Lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # [4] Build your datasets on each worker
    datamodule = MyLightningDataModule(batch_size=32)
    trainer.fit(model, datamodule=datamodule)

# [5] Explicitly define and run the training function
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
    run_config=ray.train.RunConfig(
        checkpoint_config=ray.train.CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
result = ray_trainer.fit()

# [6] Load the trained model from a simplified checkpoint interface.
checkpoint: ray.train.Checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    print("Checkpoint contents:", os.listdir(checkpoint_dir))
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ckpt")
    model = MyLightningModule.load_from_checkpoint(checkpoint_path)