ray.train.trainer.BaseTrainer#
- class ray.train.trainer.BaseTrainer(*args, **kwargs)[源代码]#
基类:
ABC
定义Ray分布式训练的接口。
注意:基类
BaseTrainer
不能直接实例化。只能使用它的一个子类。开发者须知:如果添加了新的训练器,请更新
air/_internal/usage.py
。训练器是如何工作的?
首先,初始化 Trainer。初始化过程在本地运行,因此不应在
__init__
中进行重量级设置。然后,当你调用
trainer.fit()
时,训练器会被序列化并复制到一个远程的 Ray 执行体。随后,以下方法会在远程执行体上按顺序调用。trainer.setup()
: 任何重量级的 Trainer 设置应在此处指定。trainer.training_loop()
: 执行主要的训练逻辑。调用
trainer.fit()
将返回一个ray.result.Result
对象,您可以在其中访问训练运行的指标,以及可能保存的任何检查点。
如何创建一个新的训练器?
子类化
ray.train.trainer.BaseTrainer
,并重写training_loop
方法,以及可选地重写setup
。import torch from ray.train.trainer import BaseTrainer from ray import train, tune class MyPytorchTrainer(BaseTrainer): def setup(self): self.model = torch.nn.Linear(1, 1) self.optimizer = torch.optim.SGD( self.model.parameters(), lr=0.1) def training_loop(self): # You can access any Trainer attributes directly in this method. # self.datasets["train"] has already been dataset = self.datasets["train"] torch_ds = dataset.iter_torch_batches(dtypes=torch.float) loss_fn = torch.nn.MSELoss() for epoch_idx in range(10): loss = 0 num_batches = 0 torch_ds = dataset.iter_torch_batches( dtypes=torch.float, batch_size=2 ) for batch in torch_ds: X = torch.unsqueeze(batch["x"], 1) y = torch.unsqueeze(batch["y"], 1) # Compute prediction error pred = self.model(X) batch_loss = loss_fn(pred, y) # Backpropagation self.optimizer.zero_grad() batch_loss.backward() self.optimizer.step() loss += batch_loss.item() num_batches += 1 loss /= num_batches # Use Tune functions to report intermediate # results. train.report({"loss": loss, "epoch": epoch_idx}) # Initialize the Trainer, and call Trainer.fit() import ray train_dataset = ray.data.from_items( [{"x": i, "y": i} for i in range(10)]) my_trainer = MyPytorchTrainer(datasets={"train": train_dataset}) result = my_trainer.fit()
- 参数:
scaling_config – 如何缩放训练的配置。
run_config – 训练运行的配置。
datasets – 用于训练的任何数据集。使用键 “train” 来表示哪个数据集是训练数据集。
metadata – 应通过
train.get_context().get_metadata()
和从该训练器保存的检查点的checkpoint.get_metadata()
提供的字典。必须是 JSON 可序列化的。resume_from_checkpoint – 用于从中恢复训练的检查点。
开发者API: 此API可能会在Ray的次要版本之间发生变化。
方法
将自身转换为
tune.Trainable
类。检查给定目录是否包含一个可恢复的 Train 实验。
运行训练。
已弃用。
从之前中断/失败的运行中恢复一个训练实验。
在调用 fit() 时执行初始设置的 Trainer。
由 fit() 调用的循环,用于运行训练并将结果报告给 Tune。