Train a PyTorch Lightning Image Classifier#
This example introduces how to train a PyTorch Lightning module with the Ray Train TorchTrainer.
It demonstrates how to train a basic neural network on the MNIST dataset with distributed data parallelism.
!pip install "torchmetrics>=0.9" "pytorch_lightning>=1.6"
import os
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from filelock import FileLock
from torch.utils.data import DataLoader, random_split, Subset
from torchmetrics import Accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import trainer
from pytorch_lightning.loggers.csv_logs import CSVLogger
Prepare a dataset and module#
The PyTorch Lightning Trainer takes either torch.utils.data.DataLoader
or pl.LightningDataModule
as data inputs. You can keep using them with Ray Train without any changes.
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=100):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )

            # Split data into train and val sets
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        with FileLock(f"{self.data_dir}.lock"):
            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )
            return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)
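As noted above, you could also skip the LightningDataModule and pass plain DataLoaders to trainer.fit. The following is a minimal sketch of that alternative; the helper name build_mnist_dataloaders is illustrative, and it reuses the imports and split sizes from the datamodule above:

def build_mnist_dataloaders(batch_size=100):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    mnist = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    # Same 55000/5000 train/val split as MNISTDataModule above
    mnist_train, mnist_val = random_split(mnist, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=batch_size, num_workers=4)
    val_loader = DataLoader(mnist_val, batch_size=batch_size, num_workers=4)
    return train_loader, val_loader

# Inside the training function you would then call, for example:
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)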
Next, define a simple multi-layer perceptron as a subclass of pl.LightningModule.
class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr=1e-3, feature_dim=128):
        torch.manual_seed(421)
        super(MNISTClassifier, self).__init__()
        self.save_hyperparameters()

        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 10),
            nn.ReLU(),
        )
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
        self.eval_loss = []
        self.eval_accuracy = []
        self.test_accuracy = []
        pl.seed_everything(888)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.linear_relu_stack(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        loss, acc = self._shared_eval(val_batch)
        self.log("val_accuracy", acc)
        self.eval_loss.append(loss)
        self.eval_accuracy.append(acc)
        return {"val_loss": loss, "val_accuracy": acc}

    def test_step(self, test_batch, batch_idx):
        loss, acc = self._shared_eval(test_batch)
        self.test_accuracy.append(acc)
        self.log("test_accuracy", acc, sync_dist=True, on_epoch=True)
        return {"test_loss": loss, "test_accuracy": acc}

    def _shared_eval(self, batch):
        x, y = batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return loss, acc

    def on_validation_epoch_end(self):
        avg_loss = torch.stack(self.eval_loss).mean()
        avg_acc = torch.stack(self.eval_accuracy).mean()
        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val_accuracy", avg_acc, sync_dist=True)
        self.eval_loss.clear()
        self.eval_accuracy.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
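Before wiring the module into Ray Train, you can sanity-check it locally with a random batch. A minimal sketch; the dummy tensor is the only assumption here:

model = MNISTClassifier(lr=1e-3, feature_dim=128)
dummy_images = torch.randn(4, 1, 28, 28)  # a fake batch of four MNIST-sized images
logits = model(dummy_images)
print(logits.shape)  # torch.Size([4, 10]): one logit per digit class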
You don't need to modify the definition of the PyTorch Lightning model or datamodule.
Define a training function#
This code defines a training function for each worker. Compare the training function with the original PyTorch Lightning code and note three main differences:

Distributed strategy: Use RayDDPStrategy.
Cluster environment: Use RayLightningEnvironment.
Parallel devices: Always set devices="auto" to use all devices configured by TorchTrainer.

For more information, see Getting Started with PyTorch Lightning.

For checkpoint reporting, Ray Train provides a minimal RayTrainReportCallback
class that reports metrics and a checkpoint at the end of each training epoch. For more complex checkpointing logic, implement a custom callback, as sketched below. See Saving and Loading Checkpoints.
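For reference, the following is a minimal sketch of what such a custom callback could look like, assuming a Ray version where ray.train.report and ray.train.Checkpoint are available. The class name is illustrative, and RayTrainReportCallback remains the simpler default:

import os
import tempfile

import pytorch_lightning as pl
from ray import train


class CustomEpochReportCallback(pl.callbacks.Callback):
    # A sketch: report Lightning's logged metrics and a checkpoint every epoch.
    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the metrics Lightning has logged so far as plain Python floats.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}
        metrics["epoch"] = trainer.current_epoch

        with tempfile.TemporaryDirectory() as tmpdir:
            # Save a Lightning checkpoint and hand it to Ray Train.
            trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint.ckpt"))
            train.report(metrics, checkpoint=train.Checkpoint.from_directory(tmpdir))

You would then pass callbacks=[CustomEpochReportCallback()] to pl.Trainer in place of RayTrainReportCallback.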
use_gpu = True  # Set this to False if you want to run without GPUs
num_workers = 4
import pytorch_lightning as pl
from ray.train import RunConfig, ScalingConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
from ray.train.lightning import (
RayDDPStrategy,
RayLightningEnvironment,
RayTrainReportCallback,
prepare_trainer,
)
def train_func_per_worker():
    model = MNISTClassifier(lr=1e-3, feature_dim=128)
    datamodule = MNISTDataModule(batch_size=128)

    trainer = pl.Trainer(
        devices="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        max_epochs=10,
        accelerator="gpu" if use_gpu else "cpu",
        log_every_n_steps=100,
        logger=CSVLogger("logs"),
    )
    trainer = prepare_trainer(trainer)

    # Train the model
    trainer.fit(model, datamodule=datamodule)

    # Evaluate on the test dataset
    trainer.test(model, datamodule=datamodule)
Now put everything together:
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

run_config = RunConfig(
    name="ptl-mnist-example",
    storage_path="/tmp/ray_results",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,
        checkpoint_score_attribute="val_accuracy",
        checkpoint_score_order="max",
    ),
)

trainer = TorchTrainer(
    train_func_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
)
Now fit your trainer:
result = trainer.fit()
Tune Status
Current time: 2023-08-07 23:41:11
Running for: 00:00:39.80
Memory: 24.2/186.6 GiB

System Info
Using FIFO scheduling algorithm.
Logical resource usage: 1.0/48 CPUs, 4.0/4 GPUs

Trial Status
Trial name | status | loc | iter | total time (s) | train_loss | val_accuracy | val_loss
---|---|---|---|---|---|---|---
TorchTrainer_78346_00000 | TERMINATED | 10.0.6.244:120026 | 10 | 29.0221 | 0.0315938 | 0.970002 | -12.3466
(TorchTrainer pid=120026) Starting distributed worker processes: ['120176 (10.0.6.244)', '120177 (10.0.6.244)', '120178 (10.0.6.244)', '120179 (10.0.6.244)']
(RayTrainWorker pid=120176) Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=120176) [rank: 0] Global seed set to 888
(RayTrainWorker pid=120176) GPU available: True (cuda), used: True
(RayTrainWorker pid=120176) TPU available: False, using: 0 TPU cores
(RayTrainWorker pid=120176) IPU available: False, using: 0 IPUs
(RayTrainWorker pid=120176) HPU available: False, using: 0 HPUs
(RayTrainWorker pid=120178) Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
(RayTrainWorker pid=120178) Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 115619443.32it/s]
(RayTrainWorker pid=120179) Extracting /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_3/MNIST/raw
(RayTrainWorker pid=120177) Missing logger folder: logs/lightning_logs
(RayTrainWorker pid=120176) LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
(RayTrainWorker pid=120176) | Name | Type | Params
(RayTrainWorker pid=120176) ---------------------------------------------------------
(RayTrainWorker pid=120176) 0 | linear_relu_stack | Sequential | 101 K
(RayTrainWorker pid=120176) 1 | accuracy | MulticlassAccuracy | 0
(RayTrainWorker pid=120176) ---------------------------------------------------------
(RayTrainWorker pid=120176) 101 K Trainable params
(RayTrainWorker pid=120176) 0 Non-trainable params
(RayTrainWorker pid=120176) 101 K Total params
(RayTrainWorker pid=120176) 0.407 Total estimated model params size (MB)
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 2.69it/s]
(RayTrainWorker pid=120176) /mnt/cluster_storage/pypi/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:432: PossibleUserWarning: It is recommended to use `self.log('val_accuracy', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
(RayTrainWorker pid=120176) warning_cache.warn(
(RayTrainWorker pid=120179) [rank: 3] Global seed set to 888 [repeated 7x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
Epoch 0: 100%|██████████| 108/108 [00:01<00:00, 105.79it/s, v_num=0]
Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 205.79it/s]
Epoch 0: 100%|██████████| 108/108 [00:01<00:00, 79.84it/s, v_num=0]
(RayTrainWorker pid=120178) Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/t10k-labels-idx1-ubyte.gz [repeated 15x across cluster]
(RayTrainWorker pid=120178) Extracting /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/rank_2/MNIST/raw [repeated 15x across cluster]
[... per-batch training and validation progress bars for epochs 1-8 trimmed ...]
Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 98.27it/s, v_num=0]
Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 196.99it/s]
Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 74.63it/s, v_num=0]
(RayTrainWorker pid=120176) `Trainer.fit` stopped: `max_epochs=10` reached.
100%|██████████| 4542/4542 [00:00<00:00, 48474627.91it/s] [repeated 14x across cluster]
100%|██████████| 9912422/9912422 [00:00<00:00, 90032420.31it/s] [repeated 2x across cluster]
(RayTrainWorker pid=120176) [repeated 5x across cluster]
(RayTrainWorker pid=120178) Missing logger folder: logs/lightning_logs [repeated 2x across cluster]
(RayTrainWorker pid=120179) LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3] [repeated 3x across cluster]
(RayTrainWorker pid=120176) [rank: 0] Global seed set to 888
Epoch 9: 100%|██████████| 108/108 [00:01<00:00, 66.61it/s, v_num=0]
(RayTrainWorker pid=120176) /mnt/cluster_storage/pypi/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:225: PossibleUserWarning: Using `DistributedSampler` with the dataloaders. During `trainer.test()`, it is recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates some samples to make sure all devices have same batch size in case of uneven inputs.
(RayTrainWorker pid=120176) rank_zero_warn(
Testing DataLoader 0: 25%|██▌ | 5/20 [00:00<00:00, 146.57it/s]
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 163.98it/s]
Testing DataLoader 0: 100%|██████████| 20/20 [00:00<00:00, 125.34it/s]
(RayTrainWorker pid=120176) ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
(RayTrainWorker pid=120176) ┃ Test metric ┃ DataLoader 0 ┃
(RayTrainWorker pid=120176) ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
(RayTrainWorker pid=120176) │ test_accuracy │ 0.9740999937057495 │
(RayTrainWorker pid=120176) └───────────────────────────┴───────────────────────────┘
2023-08-07 23:41:11,072 INFO tune.py:1145 -- Total run time: 39.92 seconds (39.80 seconds for the tuning loop).
Check training results and checkpoints#
result
Result(
metrics={'train_loss': 0.03159375861287117, 'val_accuracy': 0.9700015783309937, 'val_loss': -12.346583366394043, 'epoch': 9, 'step': 1080},
path='/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31',
checkpoint=LegacyTorchCheckpoint(local_path=/tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31/checkpoint_000009)
)
print("Validation Accuracy: ", result.metrics["val_accuracy"])
print("Trial Directory: ", result.path)
print(sorted(os.listdir(result.path)))
Validation Accuracy: 0.9700015783309937
Trial Directory: /tmp/ray_results/ptl-mnist-example/TorchTrainer_78346_00000_0_2023-08-07_23-40-31
['checkpoint_000007', 'checkpoint_000008', 'checkpoint_000009', 'events.out.tfevents.1691476838.ip-10-0-6-244', 'params.json', 'params.pkl', 'progress.csv', 'rank_0', 'rank_0.lock', 'rank_1', 'rank_1.lock', 'rank_2', 'rank_2.lock', 'rank_3', 'rank_3.lock', 'result.json']
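The trial directory also contains per-iteration metrics in progress.csv. A minimal sketch for inspecting them, assuming pandas is installed; the exact column names depend on the metrics that were reported:

import pandas as pd

progress = pd.read_csv(os.path.join(result.path, "progress.csv"))
# train_loss and val_accuracy come from the metrics logged in the LightningModule.
print(progress[["training_iteration", "train_loss", "val_accuracy"]].tail())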
Ray Train saved three checkpoints (checkpoint_000007,
checkpoint_000008,
checkpoint_000009) in the trial directory. The code below retrieves the latest checkpoint from the fit result and loads it back into the model.

If you lose the in-memory result object, you can restore the model from the checkpoint file. In this run the checkpoint path is: /tmp/ray_results/ptl-mnist-example/TorchTrainer_eb925_00000_0_2023-08-07_23-15-06/checkpoint_000009/checkpoint.ckpt.
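As an aside, because CheckpointConfig kept the top three checkpoints ranked by val_accuracy, you can also select the best-scoring checkpoint instead of the latest one. A minimal sketch, assuming the Result.best_checkpoints API:

# best_checkpoints is a list of (Checkpoint, metrics) pairs for the retained checkpoints.
best_ckpt, best_ckpt_metrics = max(
    result.best_checkpoints, key=lambda pair: pair[1]["val_accuracy"]
)
print("Best checkpoint val_accuracy:", best_ckpt_metrics["val_accuracy"])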
checkpoint = result.checkpoint
with checkpoint.as_directory() as ckpt_dir:
    best_model = MNISTClassifier.load_from_checkpoint(f"{ckpt_dir}/checkpoint.ckpt")
best_model
Global seed set to 888
MNISTClassifier(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
    (3): ReLU()
  )
  (accuracy): MulticlassAccuracy()
)
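As a quick usage example, you can run the restored model on a few test images. This sketch reuses the imports from earlier cells and downloads the MNIST test split into the working directory:

best_model.eval()
device = next(best_model.parameters()).device

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
test_set = MNIST(os.getcwd(), train=False, download=True, transform=transform)
images, labels = next(iter(DataLoader(test_set, batch_size=8)))

with torch.no_grad():
    preds = best_model(images.to(device)).argmax(dim=1).cpu()

print("Predictions:", preds.tolist())
print("Labels:     ", labels.tolist())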
See also#
Getting Started with PyTorch Lightning for a tutorial on using Ray Train and PyTorch Lightning
Ray Train Examples for more use cases