监控和日志指标#

Ray Train 提供了一个 API,用于从训练函数(在分布式工作节点上运行)向 Trainer``(执行您的 Python 脚本的地方)报告中间结果和检查点,通过调用 ``train.report(metrics)。结果将从分布式工作节点收集,并传递给驱动程序以进行记录和显示。

警告

只有来自排名为0的工作者的结果会被使用。然而,为了确保一致性,必须在每个工作者上调用 train.report()。如果你想从多个工作者中聚合结果,请参见 如何从不同的工作者获取并汇总结果?

报告的主要用途是在每个训练周期结束时记录指标(如准确率、损失等)。

from ray import train

def train_func():
    ...
    for i in range(num_epochs):
        result = model.train(...)
        train.report({"result": result})

在 PyTorch Lightning 中,我们使用回调来调用 train.report()

from ray import train
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class MyRayTrainReportCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        metrics = {k: v.item() for k, v in metrics.items()}

        train.report(metrics=metrics)

def train_func_per_worker():
    ...
    trainer = pl.Trainer(
        # ...
        callbacks=[MyRayTrainReportCallback()]
    )
    trainer.fit()

如何从不同的工作者获取并汇总结果?#

在实际应用中,除了准确率和损失,您可能还希望计算其他优化指标:召回率、精确率、Fbeta等。您可能还希望从多个工作节点收集指标。虽然Ray Train目前仅报告来自rank 0工作节点的指标,但您可以使用第三方库或机器学习框架的分布式原语来报告来自多个工作节点的指标。

Ray Train 原生支持 TorchMetrics,它为分布式、可扩展的 PyTorch 模型提供了一系列机器学习指标。

以下是一个报告所有工作者的聚合R2分数以及平均训练和验证损失的示例。


# First, pip install torchmetrics
# This code is tested with torchmetrics==0.7.3 and torch==1.12.1

import ray.train.torch
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

import torch
import torch.nn as nn
import torchmetrics
from torch.optim import Adam
import numpy as np


def train_func(config):
    n = 100
    # create a toy dataset
    X = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    X_valid = torch.Tensor(np.random.normal(0, 1, size=(n, 4)))
    Y = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    Y_valid = torch.Tensor(np.random.uniform(0, 1, size=(n, 1)))
    # toy neural network : 1-layer
    # wrap the model in DDP
    model = ray.train.torch.prepare_model(nn.Linear(4, 1))
    criterion = nn.MSELoss()

    mape = torchmetrics.MeanAbsolutePercentageError()
    # for averaging loss
    mean_valid_loss = torchmetrics.MeanMetric()

    optimizer = Adam(model.parameters(), lr=3e-4)
    for epoch in range(config["num_epochs"]):
        model.train()
        y = model.forward(X)

        # compute loss
        loss = criterion(y, Y)

        # back-propagate loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # evaluate
        model.eval()
        with torch.no_grad():
            pred = model(X_valid)
            valid_loss = criterion(pred, Y_valid)
            # save loss in aggregator
            mean_valid_loss(valid_loss)
            mape(pred, Y_valid)

        # collect all metrics
        # use .item() to obtain a value that can be reported
        valid_loss = valid_loss.item()
        mape_collected = mape.compute().item()
        mean_valid_loss_collected = mean_valid_loss.compute().item()

        train.report(
            {
                "mape_collected": mape_collected,
                "valid_loss": valid_loss,
                "mean_valid_loss_collected": mean_valid_loss_collected,
            }
        )

        # reset for next epoch
        mape.reset()
        mean_valid_loss.reset()


trainer = TorchTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics["valid_loss"], result.metrics["mean_valid_loss_collected"])
# 0.5109779238700867 0.5512474775314331