Tune 中的回调与指标指南#

如何在 Ray Tune 中使用回调?#

Ray Tune 支持在训练过程的不同时间调用的回调函数。回调函数可以作为参数传递给 RunConfig,由 Tuner 接收,并且您提供的子方法将被自动调用。

这个简单的回调函数每次收到结果时都会打印一个指标:

from ray import train, tune
from ray.train import RunConfig
from ray.tune import Callback


class MyCallback(Callback):
    def on_trial_result(self, iteration, trials, trial, result, **info):
        print(f"Got result: {result['metric']}")


def train_fn(config):
    for i in range(10):
        train.report({"metric": i})


tuner = tune.Tuner(
    train_fn,
    run_config=RunConfig(callbacks=[MyCallback()]))
tuner.fit()

更多详情和可用钩子,请 参阅 Ray Tune 回调的 API 文档

如何在 Tune 中使用日志指标?#

在函数和类训练API中,您可以记录任意值和指标:

def trainable(config):
    for i in range(num_epochs):
        ...
        train.report({"acc": accuracy, "metric_foo": random_metric_1, "bar": metric_2})

class Trainable(tune.Trainable):
    def step(self):
        ...
        # don't call report here!
        return dict(acc=accuracy, metric_foo=random_metric_1, bar=metric_2)

小技巧

请注意,train.report() 并不用于传输大量数据,如模型或数据集。这样做可能会导致大量开销,并显著减慢您的 Tune 运行速度。

哪些 Tune 指标会自动填充?#

Tune 有一个自动填充指标的概念。在训练期间,Tune 除了用户提供的值外,还会自动记录以下指标。所有这些都可以用作停止条件,或作为参数传递给试验调度器/搜索算法。

  • config: 超参数配置

  • date: 结果处理时的字符串格式日期和时间

  • done: 如果试验已完成则为 True,否则为 False

  • episodes_total: 总集数(用于 RLlib 可训练对象)

  • experiment_id: 唯一的实验ID

  • experiment_tag: 唯一的实验标签(包含参数值)

  • hostname: 工作节点的主机名

  • iterations_since_restore: 从检查点恢复工作程序后,train.report 被调用的次数

  • node_ip: 工作节点的宿主机IP

  • pid: 工作进程的进程ID (PID)

  • time_since_restore: 自从从检查点恢复以来的时间,单位为秒。

  • time_this_iter_s: 当前训练迭代的时间,以秒为单位(即对可训练函数的一次调用,或在类API中对 _train() 的一次调用)。

  • time_total_s: 总运行时间,单位为秒。

  • timestamp: 结果被处理时的时间戳

  • timesteps_since_restore: 自从从检查点恢复以来的时间步数

  • timesteps_total: 总时间步数

  • training_iteration: train.report() 被调用的次数

  • trial_id: 唯一的试验ID

所有这些指标都可以在 Trial.last_result 字典中看到。