ray.train.tensorflow.keras.ReportCheckpointCallback#
- class ray.train.tensorflow.keras.ReportCheckpointCallback(*args: Any, **kwargs: Any)#
基类:
_Callback
用于 Ray Train 报告和检查点的 Keras 回调。
备注
即使事件未在
report_metrics_on
中指定,指标总是与检查点一起报告。示例
############# Using it in TrainSession ############### from ray.air.integrations.keras import ReportCheckpointCallback def train_loop_per_worker(): strategy = tf.distribute.MultiWorkerMirroredStrategy() with strategy.scope(): model = build_model() model.fit(dataset_shard, callbacks=[ReportCheckpointCallback()])
- 参数:
metrics – 要报告的指标。如果这是一个列表,每个项目描述了报告给 Keras 的指标键,并且它在相同名称下报告。如果这是一个字典,每个键是报告的名称,相应的值是报告给 Keras 的指标键。如果这是 None,则报告所有 Keras 日志。
report_metrics_on – 何时报告指标。必须是 Keras 事件钩子之一(去掉
on_
前缀),例如 “train_start” 或 “predict_end”。默认为 “epoch_end”。checkpoint_on – 何时保存检查点。必须是 Keras 事件钩子之一(去掉
on_
前缀),例如 “train_start” 或 “predict_end”。默认为 “epoch_end”。
PublicAPI (alpha): 此API处于alpha阶段,可能在稳定之前发生变化。
方法