ray.train.huggingface.transformers.RayTrainReportCallback#

class ray.train.huggingface.transformers.RayTrainReportCallback(*args: Any, **kwargs: Any)[源代码]#

基类:TrainerCallback

一个简单的回调函数,用于向 Ray Tarin 报告检查点和指标。

此回调是 transformers.TrainerCallback 的子类,并重写了 TrainerCallback.on_save() 方法。在新检查点保存后,它从 TrainerState.log_history 获取最新的指标字典,并将其与最新的检查点一起报告给 Ray Train。

检查点将按以下结构保存:

checkpoint_00000*/   Ray Train Checkpoint
└─ checkpoint/       Hugging Face Transformers Checkpoint

对于自定义的报告和检查点逻辑,请按照此用户指南实现您自己的 transformers.TrainerCallback保存和加载检查点

请注意,用户应确保日志记录、评估和保存频率已正确配置,以便在 transformers.Trainer 保存检查点时,监控指标始终是最新的。

假设监控指标是在评估阶段报告的:

一些有效的配置:
  • evaluation_strategy == save_strategy == “epoch”

  • evaluation_strategy == save_strategy == “steps”, save_steps % eval_steps == 0

一些无效的配置:
  • evaluation_strategy != save_strategy

  • evaluation_strategy == save_strategy == “steps”, save_steps % eval_steps != 0

PublicAPI (测试版): 此API目前处于测试阶段,在成为稳定版本之前可能会发生变化。

方法

on_save

在检查点保存后调用的事件。

属性

CHECKPOINT_NAME