回调函数
回调是可以自定义PyTorch中训练循环行为的对象 Trainer(此功能尚未在TensorFlow中实现),它可以检查训练循环 状态(用于进度报告、在TensorBoard或其他ML平台上记录日志…)并做出决策(如提前 停止)。
回调函数是“只读”的代码片段,除了它们返回的TrainerControl对象外,它们无法改变训练循环中的任何内容。对于需要在训练循环中进行更改的自定义,你应该继承Trainer并重写你需要的方法(参见trainer中的示例)。
默认情况下,TrainingArguments.report_to
设置为 "all"
,因此 Trainer 将使用以下回调。
- DefaultFlowCallback 处理日志记录、保存和评估的默认行为。
- PrinterCallback 或 ProgressCallback 用于显示进度并打印日志(如果你通过 TrainingArguments 停用 tqdm,则使用第一个,否则使用第二个)。
- TensorBoardCallback 如果 tensorboard 可访问(通过 PyTorch >= 1.4 或 tensorboardX)。
- WandbCallback 如果安装了 wandb。
- CometCallback 如果安装了 comet_ml。
- MLflowCallback 如果安装了 mlflow。
- NeptuneCallback 如果安装了 neptune。
- AzureMLCallback 如果安装了 azureml-sdk。
- CodeCarbonCallback 如果安装了 codecarbon。
- ClearMLCallback 如果安装了 clearml。
- DagsHubCallback 如果安装了 dagshub。
- FlyteCallback 如果安装了 flyte。
- DVCLiveCallback 如果安装了 dvclive。
如果已安装了一个包但您不希望使用其附带的集成,您可以将TrainingArguments.report_to
更改为仅包含您希望使用的集成列表(例如["azure_ml", "wandb"]
)。
实现回调的主要类是TrainerCallback。它获取用于实例化Trainer的TrainingArguments,可以通过TrainerState访问该Trainer的内部状态,并且可以通过TrainerControl在训练循环中采取一些操作。
可用的回调
以下是库中可用的TrainerCallback列表:
一个TrainerCallback,它将日志发送到Comet ML。
设置可选的Comet集成。
环境:
- COMET_MODE (
str
, 可选, 默认为get_or_create
): 控制是创建并记录到新的Comet实验还是附加到现有实验。 它接受以下值:get_or_create
: 根据COMET_EXPERIMENT_KEY
是否设置以及是否存在具有该键的实验自动决定。create
: 始终创建一个新的Comet实验。get
: 始终尝试附加到现有的Comet实验。 需要设置COMET_EXPERIMENT_KEY
。ONLINE
: 已弃用, 用于创建在线实验。请改用COMET_START_ONLINE=1
。OFFLINE
: 已弃用, 用于创建离线实验。请改用COMET_START_ONLINE=0
。DISABLED
: 已弃用, 用于禁用Comet记录。请改用--report_to
标志来控制用于记录结果的集成。
- COMET_PROJECT_NAME (
str
, 可选): 用于实验的Comet项目名称。 - COMET_LOG_ASSETS (
str
, 可选, 默认为TRUE
): 是否将训练资产(tf事件日志、检查点等)记录到Comet。可以是TRUE
或FALSE
。
有关环境中许多可配置项的信息,请参见 这里。
一个TrainerCallback,用于处理日志、评估和检查点的训练循环的默认流程。
一个简单的TrainerCallback,仅用于打印日志。
一个TrainerCallback,用于显示训练或评估的进度。
你可以修改max_str_len
来控制日志记录时字符串的截断长度。
类 transformers.EarlyStoppingCallback
< source >( early_stopping_patience: int = 1 early_stopping_threshold: typing.Optional[float] = 0.0 )
一个TrainerCallback,用于处理提前停止。
此回调依赖于TrainingArguments参数load_best_model_at_end功能来设置TrainerState中的best_metric。请注意,如果TrainingArguments参数save_steps与eval_steps不同,则早期停止将不会发生,直到下一个保存步骤。
类 transformers.integrations.TensorBoardCallback
< source >( tb_writer = 无 )
一个TrainerCallback,它将日志发送到TensorBoard。
一个TrainerCallback,用于将指标、媒体、模型检查点记录到Weight and Biases。
设置可选的Weights & Biases(wandb)集成。
如果需要,可以子类化并重写此方法以自定义设置。更多信息请参见 这里。您还可以重写以下环境 变量:
环境:
WANDB_LOG_MODEL (
str
, 可选, 默认为"false"
): 是否在训练期间记录模型和检查点。可以是"end"
,"checkpoint"
或"false"
。如果设置为"end"
,模型将在训练结束时上传。如果设置为"checkpoint"
,检查点 将每args.save_steps
上传一次。如果设置为"false"
,模型将不会上传。与load_best_model_at_end()
一起使用以上传最佳模型。在5.0版本中已弃用
在🤗 Transformers的第5版中,将
WANDB_LOG_MODEL
设置为bool
将被弃用。WANDB_WATCH (
str
, 可选 默认为"false"
): 可以是"gradients"
,"all"
,"parameters"
, 或"false"
。设置为"all"
以记录梯度和参数。WANDB_PROJECT (
str
, 可选, 默认为"huggingface"
): 将此设置为自定义字符串以将结果存储在不同的项目中。WANDB_DISABLED (
bool
, 可选, 默认为False
): 是否完全禁用wandb。设置WANDB_DISABLED=true
来禁用。
一个TrainerCallback,将日志发送到MLflow。可以通过设置环境变量DISABLE_MLFLOW_INTEGRATION = TRUE
来禁用。
设置可选的MLflow集成。
环境:
- HF_MLFLOW_LOG_ARTIFACTS (
str
, 可选): 是否使用 MLflow 的.log_artifact()
功能来记录工件。这仅在记录到远程服务器(例如 s3 或 GCS)时才有意义。如果设置为True
或 1,则会在每次保存时将 TrainingArguments 的output_dir
中的每个保存的检查点复制到本地或远程工件存储。在没有远程存储的情况下使用它只会将文件复制到您的工件位置。 - MLFLOW_TRACKING_URI (
str
, 可选): 是否在特定路径或远程服务器上存储运行。默认未设置,这将完全跳过设置跟踪URI。 - MLFLOW_EXPERIMENT_NAME (
str
, 可选, 默认为None
): 是否使用一个MLflow实验名称来启动运行。默认为None
,这将指向MLflow中的Default
实验。否则,它是一个区分大小写的实验名称,该实验将被激活。如果不存在具有此名称的实验,则会创建一个具有此名称的新实验。 - MLFLOW_TAGS (
str
, 可选): 一个键值对字典的字符串转储,将作为标签添加到MLflow运行中。示例:os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
. - MLFLOW_NESTED_RUN (
str
, 可选): 是否使用MLflow嵌套运行。如果设置为True
或1,将在当前运行中创建一个嵌套运行。 - MLFLOW_RUN_ID (
str
, 可选): 允许重新附加到现有的运行,这在从检查点恢复训练时非常有用。当设置了MLFLOW_RUN_ID
环境变量时,start_run
尝试恢复具有指定运行ID的运行,并忽略其他参数。 - MLFLOW_FLATTEN_PARAMS (
str
, 可选, 默认为False
): 是否在记录之前将参数字典展平。 - MLFLOW_MAX_LOG_PARAMS (
int
, 可选): 设置运行中记录的最大参数数量。
一个TrainerCallback,将日志发送到AzureML。
一个TrainerCallback,用于跟踪训练的CO2排放。
类 transformers.integrations.NeptuneCallback
< source >( api_token: typing.Optional[str] = None project: typing.Optional[str] = None name: typing.Optional[str] = None base_namespace: str = 'finetuning' run = None log_parameters: bool = True log_checkpoints: typing.Optional[str] = None **neptune_run_kwargs )
参数
- api_token (
str
, 可选) — 注册后获得的Neptune API令牌。 如果您已将令牌保存到NEPTUNE_API_TOKEN
环境变量中(强烈推荐),则可以省略此参数。完整的设置说明请参见docs. - 项目 (
str
, 可选) — 现有Neptune项目的名称,格式为“工作区名称/项目名称”。 您可以在Neptune的项目设置 -> 属性中找到并复制该名称。如果为None(默认值),则使用NEPTUNE_PROJECT
环境变量的值。 - name (
str
, optional) — 运行的自定义名称. - base_namespace (
str
, optional, 默认为“finetuning”) — 在Neptune运行中,将包含回调记录的所有元数据的根命名空间。 - log_parameters (
bool
, 可选, 默认为True
) — 如果为True,则记录由Trainer提供的所有Trainer参数和模型参数。 - log_checkpoints (
str
, 可选) — 如果为“same”,则每当Trainer保存检查点时上传。 如果为“last”,则仅上传最近保存的检查点。如果为“best”,则上传最佳检查点(在Trainer保存的检查点中)。如果为None
,则不上传检查点。 - run (
Run
, 可选) — 如果你想继续记录到一个现有的运行中,请传递一个Neptune运行对象。 在docs中了解更多关于恢复运行的信息。 - **neptune_run_kwargs (可选) —
额外的关键字参数,当创建新运行时直接传递给
neptune.init_run()
函数。
TrainerCallback 将日志发送到 Neptune。
有关说明和示例,请参阅Neptune文档中的Transformers集成指南。
一个TrainerCallback,将日志发送到ClearML。
环境:
- CLEARML_PROJECT (
str
, 可选, 默认为HuggingFace Transformers
): ClearML 项目名称。 - CLEARML_TASK (
str
, 可选, 默认为Trainer
): ClearML 任务名称。 - CLEARML_LOG_MODEL (
bool
, 可选, 默认为False
): 是否在训练期间将模型记录为工件。
一个TrainerCallback,用于记录到DagsHub。扩展自MLflowCallback
设置DagsHub的日志记录集成。
环境:
- HF_DAGSHUB_LOG_ARTIFACTS (
str
, 可选): 是否保存实验的数据和模型工件。默认为False
。
类 transformers.integrations.FlyteCallback
< source >( save_log_history: bool = True sync_checkpoints: bool = True )
一个TrainerCallback,它将日志发送到Flyte。 注意:此回调仅在Flyte任务内有效。
类 transformers.integrations.DVCLiveCallback
< source >( live: typing.Optional[typing.Any] = None log_model: typing.Union[typing.Literal['all'], bool, NoneType] = None **kwargs )
参数
- live (
dvclive.Live
, 可选, 默认为None
) — 可选的 Live 实例。如果为 None,将使用 **kwargs 创建一个新实例。 - log_model (Union[Literal[“all”], bool], 可选, 默认为
None
) — 是否使用dvclive.Live.log_artifact()
来记录由 Trainer 创建的检查点。如果设置为True
, 则在训练结束时记录最终检查点。如果设置为"all"
,则在每个检查点记录整个 TrainingArguments 的output_dir
。
一个TrainerCallback,将日志发送到DVCLive。
在setup
中使用以下环境变量来配置集成。要自定义此回调超出这些环境变量的范围,请参见这里。
设置可选的DVCLive集成。要自定义此回调超出以下环境变量的范围,请参阅 这里。
环境:
- HF_DVCLIVE_LOG_MODEL (
str
, 可选): 是否使用dvclive.Live.log_artifact()
来记录由 Trainer 创建的检查点。如果设置为True
或 1,则在训练结束时记录最终的检查点。如果设置为all
,则在每个检查点时记录整个 TrainingArguments 的output_dir
。
TrainerCallback
类 transformers.TrainerCallback
< source >( )
参数
- args (TrainingArguments) — 用于实例化Trainer的训练参数。
- state (TrainerState) — Trainer的当前状态.
- 控制 (TrainerControl) — 返回给Trainer的对象,可用于做出一些决策。
- model (PreTrainedModel or
torch.nn.Module
) — 正在训练的模型。 - tokenizer (PreTrainedTokenizer) —
用于编码数据的标记器。这已被弃用,推荐使用
processing_class
。 - processing_class ([
PreTrainedTokenizer
或BaseImageProcessor
或ProcessorMixin
或FeatureExtractionMixin
]) — 用于编码数据的处理类。可以是分词器、处理器、图像处理器或特征提取器。 - optimizer (
torch.optim.Optimizer
) — 用于训练步骤的优化器。 - lr_scheduler (
torch.optim.lr_scheduler.LambdaLR
) — 用于设置学习率的调度器。 - train_dataloader (
torch.utils.data.DataLoader
, optional) — 当前用于训练的数据加载器。 - eval_dataloader (
torch.utils.data.DataLoader
, optional) — 当前用于评估的数据加载器。 - metrics (
Dict[str, float]
) — The metrics computed by the last evaluation phase.这些仅在事件
on_evaluate
中可访问。 - logs (
Dict[str, float]
) — The values to log.这些只能在事件
on_log
中访问。
一个用于对象的类,这些对象将在某些事件中检查训练循环的状态并做出一些决策。在每个这些事件中,以下参数是可用的:
control
对象是唯一一个可以通过回调函数更改的对象,在这种情况下,更改它的事件应返回修改后的版本。
参数 args
、state
和 control
是所有事件的位置参数,其他所有参数都分组在 kwargs
中。
你可以在事件的签名中使用它们来解包你需要的参数。例如,请参阅简单的 PrinterCallback 的代码。
示例:
class PrinterCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
on_epoch_begin
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在纪元开始时调用的事件。
on_epoch_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在一个周期结束时调用的事件。
on_evaluate
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在评估阶段之后调用的事件。
on_init_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在Trainer初始化结束时调用的事件。
在记录最后日志后调用的事件。
on_optimizer_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在优化器步骤之后但在梯度清零之前调用的事件。用于监控梯度。
on_pre_optimizer_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在优化器步骤之前但在梯度裁剪之后调用的事件。用于监控梯度。
on_predict
< source >( args: TrainingArguments state: TrainerState control: TrainerControl metrics **kwargs )
成功预测后调用的事件。
on_prediction_step
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在预测步骤之后调用的事件。
在保存检查点后调用的事件。
on_step_begin
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在训练步骤开始时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。
on_step_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在训练步骤结束时调用的事件。如果使用梯度累积,一个训练步骤可能需要多个输入。
on_substep_end
< source >( args: TrainingArguments state: TrainerState control: TrainerControl **kwargs )
在梯度累积的子步骤结束时调用的事件。
在训练开始时调用的事件。
训练结束时调用的事件。
以下是如何使用 PyTorch Trainer 注册自定义回调的示例:
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
trainer = Trainer(
model,
args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
callbacks=[MyCallback], # We can either pass the callback class this way or an instance of it (MyCallback())
)
另一种注册回调的方法是调用 trainer.add_callback()
,如下所示:
trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())
TrainerState
类 transformers.TrainerState
< source >( epoch: typing.Optional[float] = None global_step: int = 0 max_steps: int = 0 logging_steps: int = 500 eval_steps: int = 500 save_steps: int = 500 train_batch_size: int = None num_train_epochs: int = 0 num_input_tokens_seen: int = 0 total_flos: float = 0 log_history: typing.List[typing.Dict[str, float]] = None best_metric: typing.Optional[float] = None best_model_checkpoint: typing.Optional[str] = None is_local_process_zero: bool = True is_world_process_zero: bool = True is_hyper_param_search: bool = False trial_name: str = None trial_params: typing.Dict[str, typing.Union[str, float, int, bool]] = None stateful_callbacks: typing.List[ForwardRef('TrainerCallback')] = None )
参数
- epoch (
float
, 可选) — 仅在训练期间设置,将表示训练所处的epoch(小数部分表示当前epoch完成的百分比)。 - global_step (
int
, optional, defaults to 0) — 在训练期间,表示已完成的更新步骤数。 - max_steps (
int
, optional, defaults to 0) — 当前训练期间要执行的更新步骤数。 - logging_steps (
int
, optional, 默认为 500) — 每 X 次更新步骤记录日志 - eval_steps (
int
, optional) — 每X步运行一次评估。 - save_steps (
int
, optional, defaults to 500) — 每X次更新步骤保存一次检查点。 - train_batch_size (
int
, optional) — 训练数据加载器的批量大小。仅在使用了auto_find_batch_size
时需要。 - num_input_tokens_seen (
int
, 可选, 默认为 0) — 当跟踪输入令牌时,训练期间看到的令牌数量(输入令牌的数量,而不是预测令牌的数量)。 - total_flos (
float
, optional, 默认为 0) — 自训练开始以来,模型完成的浮点操作总数(存储为浮点数以避免溢出)。 - log_history (
List[Dict[str, float]]
, optional) — 自训练开始以来完成的日志列表。 - best_metric (
float
, optional) — 在跟踪最佳模型时,迄今为止遇到的最佳指标的值。 - best_model_checkpoint (
str
, optional) — 在跟踪最佳模型时,到目前为止遇到的最佳模型的检查点名称的值。 - is_local_process_zero (
bool
, 可选, 默认为True
) — 此进程是否为本地(例如,如果在多台机器上以分布式方式进行训练,则在一台机器上)主进程。 - is_world_process_zero (
bool
, 可选, 默认为True
) — 此进程是否为全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会是True
)。 - is_hyper_param_search (
bool
, 可选, 默认为False
) — 我们是否正在使用 Trainer.hyperparameter_search 进行超参数搜索。这将影响数据在 TensorBoard 中的记录方式。 - stateful_callbacks (
List[StatefulTrainerCallback]
, 可选) — 附加到Trainer
的回调函数,其状态应被保存或恢复。 相关的回调函数应实现state
和from_state
函数。
一个包含Trainer内部状态的类,该状态将在检查点时与模型和优化器一起保存,并传递给TrainerCallback。
在这个类中,一个步骤被理解为一个更新步骤。当使用梯度累积时,一个更新步骤可能需要多次前向和后向传递:如果你使用gradient_accumulation_steps=n
,那么一个更新步骤需要经过n个批次。
从json_path
的内容创建一个实例。
将此实例的内容以JSON格式保存在json_path
中。
TrainerControl
类 transformers.TrainerControl
< source >( should_training_stop: bool = False should_epoch_stop: bool = False should_save: bool = False should_evaluate: bool = False should_log: bool = False )
参数
- should_training_stop (
bool
, optional, defaults toFalse
) — Whether or not the training should be interrupted.如果
True
,此变量将不会被重置为False
。训练将直接停止。 - should_epoch_stop (
bool
, optional, defaults toFalse
) — Whether or not the current epoch should be interrupted.如果为
True
,此变量将在下一个周期开始时重置为False
。 - should_save (
bool
, optional, defaults toFalse
) — Whether or not the model should be saved at this step.如果
True
,此变量将在下一步开始时重置为False
。 - should_evaluate (
bool
, optional, defaults toFalse
) — Whether or not the model should be evaluated at this step.如果
True
,此变量将在下一步开始时重置为False
。 - should_log (
bool
, optional, defaults toFalse
) — Whether or not the logs should be reported at this step.如果
True
,此变量将在下一步开始时重置为False
。
一个处理Trainer控制流的类。这个类由TrainerCallback使用,以在训练循环中激活一些开关。