- class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)[源代码]¶
基类:
TQDMProgressBar
Darts 的 TorchForecastingModels 进度条。
允许自定义为哪些模型阶段(完整性检查、训练、验证、预测)显示进度条。
这个类是一个 PyTorch Lightning 回调 ,可以通过 pl_trainer_kwargs 参数传递给 TorchForecastingModel 构造函数。
实际案例
>>> from darts.models import NBEATSModel >>> from darts.utils.callbacks import TFMProgressBar >>> # only display the training bar and not the validation, prediction, and sanity check bars >>> prog_bar = TFMProgressBar(enable_train_bar_only=True) >>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
- 参数
enable_sanity_check_bar (
bool
) – 是否为健全性检查启用进度条。enable_train_bar (
bool
) – 是否为训练启用进度条。enable_validation_bar (
bool
) – 是否为验证启用进度条。enable_prediction_bar (
bool
) – 是否为预测启用进度条。enable_train_bar_only (
bool
) – 是否禁用除训练进度条之外的所有进度条。**kwargs – 传递给 PyTorch Lightning 的 TQDMProgressBar 的参数。
属性
回调状态的标识符。
预测批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。
当前数据加载器中,测试批次的总数,可能会随着每个epoch而变化。
训练批次的总数,可能会随着每个epoch而变化。
验证批次的总数,对于所有验证数据加载器,这可能会随着每个epoch而变化。
验证批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。
is_disabled
is_enabled
预测描述
预测进度条
process_position
刷新率
sanity_check_description
测试描述
测试进度条
train_description
train_progress_bar
训练者
val_progress_bar
验证描述
方法
disable
()你应该提供一种禁用进度条的方法。
enable
()你应该提供一种启用进度条的方法。
get_metrics
(trainer, pl_module)将训练器收集的进度条指标与 get_standard_metrics 中的标准指标结合起来。
覆盖此项以自定义预测时的 tqdm 进度条。
覆盖此项以自定义验证初始运行时的 tqdm 进度条。
覆盖此项以自定义测试的 tqdm 进度条。
覆盖此项以自定义训练的 tqdm 进度条。
覆盖此项以自定义验证的 tqdm 进度条。
load_state_dict
(state_dict)在加载检查点时调用,实现以重新加载给定回调的
state_dict
的回调状态。on_after_backward
(trainer, pl_module)在
loss.backward()
之后和优化器步进之前调用。on_before_backward
(trainer, pl_module, loss)在
loss.backward()
之前调用。on_before_optimizer_step
(trainer, pl_module, ...)在
optimizer.step()
之前调用。on_before_zero_grad
(trainer, pl_module, ...)在
optimizer.zero_grad()
之前调用。on_exception
(trainer, pl_module, exception)当任何训练执行因异常而中断时调用。
on_fit_end
(trainer, pl_module)当拟合结束时调用。
on_fit_start
(trainer, pl_module)当拟合开始时调用。
on_load_checkpoint
(trainer, pl_module, ...)在加载模型检查点时调用,用于重新加载状态。
on_predict_batch_end
(trainer, pl_module, ...)当预测批次结束时调用。
on_predict_batch_start
(trainer, pl_module, ...)在预测批次开始时调用。
on_predict_end
(trainer, pl_module)预测结束时调用。
on_predict_epoch_end
(trainer, pl_module)在预测时期结束时调用。
on_predict_epoch_start
(trainer, pl_module)在预测周期开始时调用。
on_predict_start
(trainer, pl_module)当预测开始时调用。
当验证健全性检查结束时调用。
在验证健全性检查开始时调用。
on_save_checkpoint
(trainer, pl_module, ...)在保存检查点时调用,以便您有机会存储任何您可能想要保存的其他内容。
on_test_batch_end
(trainer, pl_module, ...[, ...])在测试批次结束时调用。
on_test_batch_start
(trainer, pl_module, ...)在测试批次开始时调用。
on_test_end
(trainer, pl_module)测试结束时调用。
on_test_epoch_end
(trainer, pl_module)在测试时期结束时调用。
on_test_epoch_start
(trainer, pl_module)在测试时期开始时调用。
on_test_start
(trainer, pl_module)测试开始时调用。
on_train_batch_end
(trainer, pl_module, ...)在训练批次结束时调用。
on_train_batch_start
(trainer, pl_module, ...)当训练批次开始时调用。
on_train_end
(*_)当列车结束时调用。
on_train_epoch_end
(trainer, pl_module)在训练轮次结束时调用。
on_train_epoch_start
(trainer, *_)在训练轮次开始时调用。
on_train_start
(*_)当列车开始时调用。
on_validation_batch_end
(trainer, pl_module, ...)当验证批次结束时调用。
on_validation_batch_start
(trainer, ...[, ...])在验证批次开始时调用。
on_validation_end
(trainer, pl_module)当验证循环结束时调用。
on_validation_epoch_end
(trainer, pl_module)在 val 时期结束时调用。
on_validation_epoch_start
(trainer, pl_module)在 val 纪元开始时调用。
on_validation_start
(trainer, pl_module)在验证循环开始时调用。
print
(*args[, sep])你应该提供一种在不中断进度条的情况下进行打印的方法。
setup
(trainer, pl_module, stage)在 fit、validate、test、predict 或 tune 开始时调用。
在保存检查点时调用,实现以生成回调的
state_dict
。teardown
(trainer, pl_module, stage)当fit、validate、test、predict或tune结束时调用。
has_dataloader_changed
reset_dataloader_idx_tracker
- BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]'¶
- disable()¶
你应该提供一种禁用进度条的方法。
- 返回类型
None
- enable()¶
你应该提供一种启用进度条的方法。
在例如 学习率查找器 这样的预训练例程中,
Trainer
会调用此方法来暂时启用和禁用训练进度条。- 返回类型
None
- get_metrics(trainer, pl_module)¶
将训练器收集的进度条指标与 get_standard_metrics 的标准指标结合。实现此功能以覆盖进度条中显示的项目。
以下是如何覆盖默认设置的示例:
def get_metrics(self, trainer, model): # don't show the version number items = super().get_metrics(trainer, model) items.pop("v_num", None) return items
- 返回类型
Dict
[str
,Union
[int
,str
,float
,Dict
[str
,float
]]]- 返回
包含要在进度条中显示的项目的字典。
- has_dataloader_changed(dataloader_idx)¶
- 返回类型
bool
- init_test_tqdm()¶
覆盖此项以自定义测试的 tqdm 进度条。
- 返回类型
Tqdm
- property is_disabled: bool¶
- 返回类型
bool
- property is_enabled: bool¶
- 返回类型
bool
- load_state_dict(state_dict)¶
在加载检查点时调用,实现以重新加载给定回调的
state_dict
的回调状态。- 参数
state_dict (
Dict
[str
,Any
]) –state_dict
返回的回调状态。- 返回类型
None
- on_after_backward(trainer, pl_module)¶
在
loss.backward()
之后和优化器步进之前调用。- 返回类型
None
- on_before_backward(trainer, pl_module, loss)¶
在
loss.backward()
之前调用。- 返回类型
None
- on_before_optimizer_step(trainer, pl_module, optimizer)¶
在
optimizer.step()
之前调用。- 返回类型
None
- on_before_zero_grad(trainer, pl_module, optimizer)¶
在
optimizer.zero_grad()
之前调用。- 返回类型
None
- on_exception(trainer, pl_module, exception)¶
当任何训练执行因异常而中断时调用。
- 返回类型
None
- on_fit_end(trainer, pl_module)¶
当拟合结束时调用。
- 返回类型
None
- on_fit_start(trainer, pl_module)¶
当拟合开始时调用。
- 返回类型
None
- on_load_checkpoint(trainer, pl_module, checkpoint)¶
在加载模型检查点时调用,用于重新加载状态。
- 参数
trainer (
Trainer
) – 当前的Trainer
实例。pl_module (
LightningModule
) – 当前的LightningModule
实例。checkpoint (
Dict
[str
,Any
]) – 由 Trainer 加载的完整检查点字典。
- 返回类型
None
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
当预测批次结束时调用。
- 返回类型
None
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
在预测批次开始时调用。
- 返回类型
None
- on_predict_end(trainer, pl_module)¶
预测结束时调用。
- 返回类型
None
- on_predict_epoch_end(trainer, pl_module)¶
在预测时期结束时调用。
- 返回类型
None
- on_predict_epoch_start(trainer, pl_module)¶
在预测周期开始时调用。
- 返回类型
None
- on_predict_start(trainer, pl_module)¶
当预测开始时调用。
- 返回类型
None
- on_sanity_check_end(*_)¶
当验证健全性检查结束时调用。
- 返回类型
None
- on_sanity_check_start(*_)¶
在验证健全性检查开始时调用。
- 返回类型
None
- on_save_checkpoint(trainer, pl_module, checkpoint)¶
在保存检查点时调用,以便您有机会存储任何您可能想要保存的其他内容。
- 参数
trainer (
Trainer
) – 当前的Trainer
实例。pl_module (
LightningModule
) – 当前的LightningModule
实例。checkpoint (
Dict
[str
,Any
]) – 将被保存的检查点字典。
- 返回类型
None
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
在测试批次结束时调用。
- 返回类型
None
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
在测试批次开始时调用。
- 返回类型
None
- on_test_end(trainer, pl_module)¶
测试结束时调用。
- 返回类型
None
- on_test_epoch_end(trainer, pl_module)¶
在测试时期结束时调用。
- 返回类型
None
- on_test_epoch_start(trainer, pl_module)¶
在测试时期开始时调用。
- 返回类型
None
- on_test_start(trainer, pl_module)¶
测试开始时调用。
- 返回类型
None
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)¶
在训练批次结束时调用。
备注
这里的
outputs["loss"]
值将是相对于accumulate_grad_batches
的损失归一化值,该损失从training_step
返回。- 返回类型
None
- on_train_batch_start(trainer, pl_module, batch, batch_idx)¶
当训练批次开始时调用。
- 返回类型
None
- on_train_end(*_)¶
当列车结束时调用。
- 返回类型
None
- on_train_epoch_end(trainer, pl_module)¶
在训练轮次结束时调用。
要在每个epoch结束时访问所有批处理输出,您可以将步骤输出缓存为
pytorch_lightning.core.LightningModule
的属性,并在此钩子中访问它们:class MyLightningModule(L.LightningModule): def __init__(self): super().__init__() self.training_step_outputs = [] def training_step(self): loss = ... self.training_step_outputs.append(loss) return loss class MyCallback(L.Callback): def on_train_epoch_end(self, trainer, pl_module): # do something with all training_step outputs, for example: epoch_mean = torch.stack(pl_module.training_step_outputs).mean() pl_module.log("training_epoch_mean", epoch_mean) # free up the memory pl_module.training_step_outputs.clear()
- 返回类型
None
- on_train_epoch_start(trainer, *_)¶
在训练轮次开始时调用。
- 返回类型
None
- on_train_start(*_)¶
当列车开始时调用。
- 返回类型
None
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
当验证批次结束时调用。
- 返回类型
None
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
在验证批次开始时调用。
- 返回类型
None
- on_validation_end(trainer, pl_module)¶
当验证循环结束时调用。
- 返回类型
None
- on_validation_epoch_end(trainer, pl_module)¶
在 val 时期结束时调用。
- 返回类型
None
- on_validation_epoch_start(trainer, pl_module)¶
在 val 纪元开始时调用。
- 返回类型
None
- on_validation_start(trainer, pl_module)¶
在验证循环开始时调用。
- 返回类型
None
- property predict_description: str¶
- 返回类型
str
- property predict_progress_bar: tqdm_asyncio¶
- 返回类型
tqdm_asyncio
- print(*args, sep=' ', **kwargs)¶
你应该提供一种在不中断进度条的情况下进行打印的方法。
- 返回类型
None
- property process_position: int¶
- 返回类型
int
- property refresh_rate: int¶
- 返回类型
int
- reset_dataloader_idx_tracker()¶
- 返回类型
None
- property sanity_check_description: str¶
- 返回类型
str
- setup(trainer, pl_module, stage)¶
在 fit、validate、test、predict 或 tune 开始时调用。
- 返回类型
None
- state_dict()¶
在保存检查点时调用,实现以生成回调的
state_dict
。- 返回类型
Dict
[str
,Any
]- 返回
包含回调状态的字典。
- property state_key: str¶
回调状态的标识符。
用于通过
checkpoint["callbacks"][state_key]
从检查点字典中存储和检索回调的状态。如果 1) 回调有状态且 2) 希望保持该回调的多个实例的状态,则回调的实现需要提供一个唯一的状态键。- 返回类型
str
- teardown(trainer, pl_module, stage)¶
当fit、validate、test、predict或tune结束时调用。
- 返回类型
None
- property test_description: str¶
- 返回类型
str
- property test_progress_bar: tqdm_asyncio¶
- 返回类型
tqdm_asyncio
- property total_predict_batches_current_dataloader: Union[int, float]¶
预测批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。
使用此项来设置进度条中的总迭代次数。如果预测数据加载器的大小是无限的,则可以返回
inf
。- 返回类型
Union
[int
,float
]
- property total_test_batches_current_dataloader: Union[int, float]¶
当前数据加载器中,测试批次的总数,可能会随着每个epoch而变化。
使用此项来设置进度条中的总迭代次数。如果测试数据加载器的大小是无限的,则可以返回
inf
。- 返回类型
Union
[int
,float
]
- property total_train_batches: Union[int, float]¶
训练批次的总数,可能会随着每个epoch而变化。
使用此项来设置进度条中的总迭代次数。如果训练数据加载器的大小是无限的,则可以返回
inf
。- 返回类型
Union
[int
,float
]
- property total_val_batches: Union[int, float]¶
验证批次的总数,对于所有验证数据加载器,这可能会随着每个epoch而变化。
使用此项来设置进度条中的总迭代次数。如果预测数据加载器的大小是无限的,则可以返回
inf
。- 返回类型
Union
[int
,float
]
- property total_val_batches_current_dataloader: Union[int, float]¶
验证批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。
使用此项来设置进度条中的总迭代次数。如果验证数据加载器的大小是无限的,则可以返回
inf
。- 返回类型
Union
[int
,float
]
- property train_description: str¶
- 返回类型
str
- property train_progress_bar: tqdm_asyncio¶
- 返回类型
tqdm_asyncio
- property trainer: Trainer¶
- 返回类型
Trainer
- property val_progress_bar: tqdm_asyncio¶
- 返回类型
tqdm_asyncio
- property validation_description: str¶
- 返回类型
str