ray.train.tensorflow.TensorflowTrainer#

class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)[源代码]#

基类:DataParallelTrainer

用于数据并行 TensorFlow 训练的训练器。

这个训练器在多个 Ray 角色上运行 train_loop_per_worker 函数。这些角色已经配置了必要的 TensorFlow 进程组,用于分布式 TensorFlow 训练。

train_loop_per_worker 函数预期接收 0 或 1 个参数:

def train_loop_per_worker():
    ...
def train_loop_per_worker(config: Dict):
    ...

如果 train_loop_per_worker 接受一个参数,那么 train_loop_config 将被作为参数传递进去。如果你想将 train_loop_config 中的值作为超参数进行调整,这会很有用。

如果 datasets 字典包含一个训练数据集(由 “train” 键表示),那么它将被拆分为多个数据集分片,这些分片可以通过 train_loop_per_worker 中的 ray.train.get_dataset_shard("train") 访问。所有其他数据集将不会被拆分,ray.train.get_dataset_shard(...) 将返回整个数据集。

train_loop_per_worker 函数内部,你可以使用任何 Ray Train 循环方法

警告

Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you desire greater control over TensorFlow threading, use the tf.config.threading module (eg. tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.

from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns dict of last saved checkpoint.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()

train_loop_per_worker 返回的任何内容都将被丢弃,不会在任何地方使用或保存。

要将模型保存以供 TensorflowPredictor 使用,您必须将其保存在传递给 train.report()Checkpoint 中的 “model” kwarg 下。

示例:

import os
import tempfile
import tensorflow as tf

import ray
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def build_model():
    # toy neural network : 1-layer
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(
            1, activation="linear", input_shape=(1,))]
    )

def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    tf_dataset = dataset_shard.to_tf(
        feature_columns="x",
        label_columns="y",
        batch_size=1
    )
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)

        # Create checkpoint.
        checkpoint_dir = tempfile.mkdtemp()
        model.save_weights(
            os.path.join(checkpoint_dir, "my_checkpoint")
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        train.report(
            {},
            checkpoint=checkpoint,
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()
参数:
  • train_loop_per_worker – 要执行的训练函数。这可以不带参数或接受一个 config 字典。

  • train_loop_config – 如果 train_loop_per_worker 接受参数,则传递给它的配置。

  • tensorflow_config – 用于设置 TensorFlow 后端的配置。如果设置为 None,则使用默认配置。这取代了 DataParallelTrainerbackend_config 参数。

  • scaling_config – 数据并行训练的配置方式。

  • dataset_config – 数据集摄取的配置。

  • run_config – 训练运行的配置。

  • datasets – 用于训练的任何数据集。使用键 “train” 来表示哪个数据集是训练数据集。

  • resume_from_checkpoint – 用于从中恢复训练的检查点。

  • metadata – 应通过 ray.train.get_context().get_metadata() 和从该训练器保存的检查点的 checkpoint.get_metadata() 提供的字典。必须是 JSON 可序列化的。

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。

方法

as_trainable

将自身转换为 tune.Trainable 类。

can_restore

检查给定目录是否包含一个可恢复的 Train 实验。

fit

运行训练。

get_dataset_config

返回此训练器的最终数据集配置的副本。

preprocess_datasets

已弃用。

restore

从之前中断/失败的运行中恢复一个 DataParallelTrainer。

setup

在调用 fit() 时执行初始设置的 Trainer。