ray.train.trainer.BaseTrainer#

class ray.train.trainer.BaseTrainer(*args, **kwargs)[源代码]#

基类:ABC

定义Ray分布式训练的接口。

注意:基类 BaseTrainer 不能直接实例化。只能使用它的一个子类。

开发者须知:如果添加了新的训练器,请更新 air/_internal/usage.py

训练器是如何工作的?

  • 首先,初始化 Trainer。初始化过程在本地运行,因此不应在 __init__ 中进行重量级设置。

  • 然后,当你调用 trainer.fit() 时,训练器会被序列化并复制到一个远程的 Ray 执行体。随后,以下方法会在远程执行体上按顺序调用。

  • trainer.setup(): 任何重量级的 Trainer 设置应在此处指定。

  • trainer.training_loop(): 执行主要的训练逻辑。

  • 调用 trainer.fit() 将返回一个 ray.result.Result 对象,您可以在其中访问训练运行的指标,以及可能保存的任何检查点。

如何创建一个新的训练器?

子类化 ray.train.trainer.BaseTrainer ,并重写 training_loop 方法,以及可选地重写 setup

import torch

from ray.train.trainer import BaseTrainer
from ray import train, tune


class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] has already been
        dataset = self.datasets["train"]

        torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate
            # results.
            train.report({"loss": loss, "epoch": epoch_idx})


# Initialize the Trainer, and call Trainer.fit()
import ray
train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(10)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()
参数:
  • scaling_config – 如何缩放训练的配置。

  • run_config – 训练运行的配置。

  • datasets – 用于训练的任何数据集。使用键 “train” 来表示哪个数据集是训练数据集。

  • metadata – 应通过 train.get_context().get_metadata() 和从该训练器保存的检查点的 checkpoint.get_metadata() 提供的字典。必须是 JSON 可序列化的。

  • resume_from_checkpoint – 用于从中恢复训练的检查点。

开发者API: 此API可能会在Ray的次要版本之间发生变化。

方法

as_trainable

将自身转换为 tune.Trainable 类。

can_restore

检查给定目录是否包含一个可恢复的 Train 实验。

fit

运行训练。

preprocess_datasets

已弃用。

restore

从之前中断/失败的运行中恢复一个训练实验。

setup

在调用 fit() 时执行初始设置的 Trainer。

training_loop

由 fit() 调用的循环,用于运行训练并将结果报告给 Tune。