ray.train.tensorflow.TensorflowTrainer#
- class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)[源代码]#
-
用于数据并行 TensorFlow 训练的训练器。
这个训练器在多个 Ray 角色上运行
train_loop_per_worker
函数。这些角色已经配置了必要的 TensorFlow 进程组,用于分布式 TensorFlow 训练。train_loop_per_worker
函数预期接收 0 或 1 个参数:def train_loop_per_worker(): ...
def train_loop_per_worker(config: Dict): ...
如果
train_loop_per_worker
接受一个参数,那么train_loop_config
将被作为参数传递进去。如果你想将train_loop_config
中的值作为超参数进行调整,这会很有用。如果
datasets
字典包含一个训练数据集(由 “train” 键表示),那么它将被拆分为多个数据集分片,这些分片可以通过train_loop_per_worker
中的ray.train.get_dataset_shard("train")
访问。所有其他数据集将不会被拆分,ray.train.get_dataset_shard(...)
将返回整个数据集。在
train_loop_per_worker
函数内部,你可以使用任何 Ray Train 循环方法。警告
Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you desire greater control over TensorFlow threading, use the
tf.config.threading
module (eg.tf.config.threading.set_inter_op_parallelism_threads(num_cpus)
) at the beginning of yourtrain_loop_per_worker
function.from ray import train def train_loop_per_worker(): # Report intermediate results for callbacks or logging and # checkpoint data. train.report(...) # Returns dict of last saved checkpoint. train.get_checkpoint() # Returns the Dataset shard for the given key. train.get_dataset_shard("my_dataset") # Returns the total number of workers executing training. train.get_context().get_world_size() # Returns the rank of this worker. train.get_context().get_world_rank() # Returns the rank of the worker on the current node. train.get_context().get_local_rank()
从
train_loop_per_worker
返回的任何内容都将被丢弃,不会在任何地方使用或保存。要将模型保存以供
TensorflowPredictor
使用,您必须将其保存在传递给train.report()
的Checkpoint
中的 “model” kwarg 下。示例:
import os import tempfile import tensorflow as tf import ray from ray import train from ray.train import Checkpoint, ScalingConfig from ray.train.tensorflow import TensorflowTrainer def build_model(): # toy neural network : 1-layer return tf.keras.Sequential( [tf.keras.layers.Dense( 1, activation="linear", input_shape=(1,))] ) def train_loop_per_worker(config): dataset_shard = train.get_dataset_shard("train") strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() with strategy.scope(): model = build_model() model.compile( optimizer="Adam", loss="mean_squared_error", metrics=["mse"]) tf_dataset = dataset_shard.to_tf( feature_columns="x", label_columns="y", batch_size=1 ) for epoch in range(config["num_epochs"]): model.fit(tf_dataset) # Create checkpoint. checkpoint_dir = tempfile.mkdtemp() model.save_weights( os.path.join(checkpoint_dir, "my_checkpoint") ) checkpoint = Checkpoint.from_directory(checkpoint_dir) train.report( {}, checkpoint=checkpoint, ) train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]) trainer = TensorflowTrainer( train_loop_per_worker=train_loop_per_worker, scaling_config=ScalingConfig(num_workers=3, use_gpu=True), datasets={"train": train_dataset}, train_loop_config={"num_epochs": 2}, ) result = trainer.fit()
- 参数:
train_loop_per_worker – 要执行的训练函数。这可以不带参数或接受一个
config
字典。train_loop_config – 如果
train_loop_per_worker
接受参数,则传递给它的配置。tensorflow_config – 用于设置 TensorFlow 后端的配置。如果设置为 None,则使用默认配置。这取代了
DataParallelTrainer
的backend_config
参数。scaling_config – 数据并行训练的配置方式。
dataset_config – 数据集摄取的配置。
run_config – 训练运行的配置。
datasets – 用于训练的任何数据集。使用键 “train” 来表示哪个数据集是训练数据集。
resume_from_checkpoint – 用于从中恢复训练的检查点。
metadata – 应通过
ray.train.get_context().get_metadata()
和从该训练器保存的检查点的checkpoint.get_metadata()
提供的字典。必须是 JSON 可序列化的。
PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。
方法
将自身转换为
tune.Trainable
类。检查给定目录是否包含一个可恢复的 Train 实验。
运行训练。
返回此训练器的最终数据集配置的副本。
已弃用。
从之前中断/失败的运行中恢复一个 DataParallelTrainer。
在调用 fit() 时执行初始设置的 Trainer。