ray.train.torch.TorchTrainer#

class ray.train.torch.TorchTrainer(*args, **kwargs)[源代码]#

基类:DataParallelTrainer

一个用于数据并行 PyTorch 训练的训练器。

从高层次来看,这个训练器执行以下操作:

  1. 根据 scaling_config 定义启动多个工作进程。

  2. 在这些工作节点上根据 torch_config 设置一个分布式的 PyTorch 环境。

  3. 基于 dataset_config 导入输入 datasets

  4. 在所有工作节点上运行输入的 train_loop_per_worker(train_loop_config)

更多详情,请参阅:

示例

import os
import tempfile

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
use_gpu = False
# Number of processes to run training on.
num_workers = 4

# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Training loop.
def train_loop_per_worker(config):

    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):

        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (model.module
            if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)


# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets
)

# Train the model.
result = trainer.fit()

# Inspect the results.
final_loss = result.metrics["loss"]
参数:
  • train_loop_per_worker – 在每个工作节点上执行的训练函数。该函数可以不带参数,或者接受一个由定义 train_loop_config 设置的单个 Dict 参数。在这个函数中,你可以使用任何 Ray 训练循环工具

  • train_loop_config – 一个配置 Dict,作为参数传递给 train_loop_per_worker。这通常用于指定超参数。不建议通过 train_loop_config 传递大型数据集,这可能会引入大量的开销和序列化与反序列化过程中的未知问题。

  • torch_config – 用于设置 PyTorch 分布式后端的配置。如果设置为 None,将使用默认配置,其中 GPU 训练使用 NCCL,CPU 训练使用 Gloo。

  • scaling_config – 数据并行训练的配置方式。num_workers 决定了用于训练的Python进程数量,use_gpu 决定了每个进程是否应使用GPU。更多信息请参见 ScalingConfig

  • run_config – 训练运行的配置。更多信息请参见 RunConfig

  • datasets – 用于训练的 Ray 数据集。数据集按名称键入({name: dataset})。每个数据集可以通过在 train_loop_per_worker 中调用 ray.train.get_dataset_shard(name) 来访问。分片和额外配置可以通过传入 dataset_config 来完成。

  • dataset_config – 用于摄取输入 datasets 的配置。默认情况下,所有 Ray Dataset 在各个工作节点之间平均分配。更多详情请参见 DataConfig

  • resume_from_checkpoint – 用于恢复训练的检查点。可以通过在 train_loop_per_worker 中调用 ray.train.get_checkpoint() 来访问此检查点。

  • metadata – 应通过 ray.train.get_context().get_metadata() 和从该训练器保存的检查点的 checkpoint.get_metadata() 提供的字典。必须是 JSON 可序列化的。

方法

as_trainable

将自身转换为 tune.Trainable 类。

can_restore

检查给定目录是否包含一个可恢复的 Train 实验。

fit

运行训练。

get_dataset_config

返回此训练器的最终数据集配置的副本。

preprocess_datasets

已弃用。

restore

从之前中断/失败的运行中恢复一个 DataParallelTrainer。

setup

在调用 fit() 时执行初始设置的 Trainer。