ray.train.tensorflow.keras.ReportCheckpointCallback#

class ray.train.tensorflow.keras.ReportCheckpointCallback(*args: Any, **kwargs: Any)#

基类:_Callback

用于 Ray Train 报告和检查点的 Keras 回调。

备注

即使事件未在 report_metrics_on 中指定,指标总是与检查点一起报告。

示例

############# Using it in TrainSession ###############
from ray.air.integrations.keras import ReportCheckpointCallback
def train_loop_per_worker():
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()

    model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])
参数:
  • metrics – 要报告的指标。如果这是一个列表,每个项目描述了报告给 Keras 的指标键,并且它在相同名称下报告。如果这是一个字典,每个键是报告的名称,相应的值是报告给 Keras 的指标键。如果这是 None,则报告所有 Keras 日志。

  • report_metrics_on – 何时报告指标。必须是 Keras 事件钩子之一(去掉 on_ 前缀),例如 “train_start” 或 “predict_end”。默认为 “epoch_end”。

  • checkpoint_on – 何时保存检查点。必须是 Keras 事件钩子之一(去掉 on_ 前缀),例如 “train_start” 或 “predict_end”。默认为 “epoch_end”。

PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。

方法