class darts.utils.callbacks.TFMProgressBar(enable_sanity_check_bar=True, enable_train_bar=True, enable_validation_bar=True, enable_prediction_bar=True, enable_train_bar_only=False, **kwargs)[源代码]

基类:TQDMProgressBar

Darts 的 TorchForecastingModels 进度条。

允许自定义为哪些模型阶段(完整性检查、训练、验证、预测)显示进度条。

这个类是一个 PyTorch Lightning 回调 ,可以通过 pl_trainer_kwargs 参数传递给 TorchForecastingModel 构造函数。

实际案例

>>> from darts.models import NBEATSModel
>>> from darts.utils.callbacks import TFMProgressBar
>>> # only display the training bar and not the validation, prediction, and sanity check bars
>>> prog_bar = TFMProgressBar(enable_train_bar_only=True)
>>> model = NBEATSModel(1, 1, pl_trainer_kwargs={"callbacks": [prog_bar]})
参数
  • enable_sanity_check_bar (bool) – 是否为健全性检查启用进度条。

  • enable_train_bar (bool) – 是否为训练启用进度条。

  • enable_validation_bar (bool) – 是否为验证启用进度条。

  • enable_prediction_bar (bool) – 是否为预测启用进度条。

  • enable_train_bar_only (bool) – 是否禁用除训练进度条之外的所有进度条。

  • **kwargs – 传递给 PyTorch Lightning 的 TQDMProgressBar 的参数。

属性

state_key

回调状态的标识符。

total_predict_batches_current_dataloader

预测批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。

total_test_batches_current_dataloader

当前数据加载器中,测试批次的总数,可能会随着每个epoch而变化。

total_train_batches

训练批次的总数,可能会随着每个epoch而变化。

total_val_batches

验证批次的总数,对于所有验证数据加载器,这可能会随着每个epoch而变化。

total_val_batches_current_dataloader

验证批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。

is_disabled

is_enabled

预测描述

预测进度条

process_position

刷新率

sanity_check_description

测试描述

测试进度条

train_description

train_progress_bar

训练者

val_progress_bar

验证描述

方法

disable()

你应该提供一种禁用进度条的方法。

enable()

你应该提供一种启用进度条的方法。

get_metrics(trainer, pl_module)

将训练器收集的进度条指标与 get_standard_metrics 中的标准指标结合起来。

init_predict_tqdm()

覆盖此项以自定义预测时的 tqdm 进度条。

init_sanity_tqdm()

覆盖此项以自定义验证初始运行时的 tqdm 进度条。

init_test_tqdm()

覆盖此项以自定义测试的 tqdm 进度条。

init_train_tqdm()

覆盖此项以自定义训练的 tqdm 进度条。

init_validation_tqdm()

覆盖此项以自定义验证的 tqdm 进度条。

load_state_dict(state_dict)

在加载检查点时调用,实现以重新加载给定回调的 state_dict 的回调状态。

on_after_backward(trainer, pl_module)

loss.backward() 之后和优化器步进之前调用。

on_before_backward(trainer, pl_module, loss)

loss.backward() 之前调用。

on_before_optimizer_step(trainer, pl_module, ...)

optimizer.step() 之前调用。

on_before_zero_grad(trainer, pl_module, ...)

optimizer.zero_grad() 之前调用。

on_exception(trainer, pl_module, exception)

当任何训练执行因异常而中断时调用。

on_fit_end(trainer, pl_module)

当拟合结束时调用。

on_fit_start(trainer, pl_module)

当拟合开始时调用。

on_load_checkpoint(trainer, pl_module, ...)

在加载模型检查点时调用,用于重新加载状态。

on_predict_batch_end(trainer, pl_module, ...)

当预测批次结束时调用。

on_predict_batch_start(trainer, pl_module, ...)

在预测批次开始时调用。

on_predict_end(trainer, pl_module)

预测结束时调用。

on_predict_epoch_end(trainer, pl_module)

在预测时期结束时调用。

on_predict_epoch_start(trainer, pl_module)

在预测周期开始时调用。

on_predict_start(trainer, pl_module)

当预测开始时调用。

on_sanity_check_end(*_)

当验证健全性检查结束时调用。

on_sanity_check_start(*_)

在验证健全性检查开始时调用。

on_save_checkpoint(trainer, pl_module, ...)

在保存检查点时调用,以便您有机会存储任何您可能想要保存的其他内容。

on_test_batch_end(trainer, pl_module, ...[, ...])

在测试批次结束时调用。

on_test_batch_start(trainer, pl_module, ...)

在测试批次开始时调用。

on_test_end(trainer, pl_module)

测试结束时调用。

on_test_epoch_end(trainer, pl_module)

在测试时期结束时调用。

on_test_epoch_start(trainer, pl_module)

在测试时期开始时调用。

on_test_start(trainer, pl_module)

测试开始时调用。

on_train_batch_end(trainer, pl_module, ...)

在训练批次结束时调用。

on_train_batch_start(trainer, pl_module, ...)

当训练批次开始时调用。

on_train_end(*_)

当列车结束时调用。

on_train_epoch_end(trainer, pl_module)

在训练轮次结束时调用。

on_train_epoch_start(trainer, *_)

在训练轮次开始时调用。

on_train_start(*_)

当列车开始时调用。

on_validation_batch_end(trainer, pl_module, ...)

当验证批次结束时调用。

on_validation_batch_start(trainer, ...[, ...])

在验证批次开始时调用。

on_validation_end(trainer, pl_module)

当验证循环结束时调用。

on_validation_epoch_end(trainer, pl_module)

在 val 时期结束时调用。

on_validation_epoch_start(trainer, pl_module)

在 val 纪元开始时调用。

on_validation_start(trainer, pl_module)

在验证循环开始时调用。

print(*args[, sep])

你应该提供一种在不中断进度条的情况下进行打印的方法。

setup(trainer, pl_module, stage)

在 fit、validate、test、predict 或 tune 开始时调用。

state_dict()

在保存检查点时调用,实现以生成回调的 state_dict

teardown(trainer, pl_module, stage)

当fit、validate、test、predict或tune结束时调用。

has_dataloader_changed

reset_dataloader_idx_tracker

BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]'
disable()

你应该提供一种禁用进度条的方法。

返回类型

None

enable()

你应该提供一种启用进度条的方法。

在例如 学习率查找器 这样的预训练例程中,Trainer 会调用此方法来暂时启用和禁用训练进度条。

返回类型

None

get_metrics(trainer, pl_module)

将训练器收集的进度条指标与 get_standard_metrics 的标准指标结合。实现此功能以覆盖进度条中显示的项目。

以下是如何覆盖默认设置的示例:

def get_metrics(self, trainer, model):
    # don't show the version number
    items = super().get_metrics(trainer, model)
    items.pop("v_num", None)
    return items
返回类型

Dict[str, Union[int, str, float, Dict[str, float]]]

返回

包含要在进度条中显示的项目的字典。

has_dataloader_changed(dataloader_idx)
返回类型

bool

init_predict_tqdm()[源代码]

覆盖此项以自定义预测时的 tqdm 进度条。

返回类型

Tqdm

init_sanity_tqdm()[源代码]

覆盖此项以自定义验证初始运行时的 tqdm 进度条。

返回类型

Tqdm

init_test_tqdm()

覆盖此项以自定义测试的 tqdm 进度条。

返回类型

Tqdm

init_train_tqdm()[源代码]

覆盖此项以自定义训练的 tqdm 进度条。

返回类型

Tqdm

init_validation_tqdm()[源代码]

覆盖此项以自定义验证的 tqdm 进度条。

返回类型

Tqdm

property is_disabled: bool
返回类型

bool

property is_enabled: bool
返回类型

bool

load_state_dict(state_dict)

在加载检查点时调用,实现以重新加载给定回调的 state_dict 的回调状态。

参数

state_dict (Dict[str, Any]) – state_dict 返回的回调状态。

返回类型

None

on_after_backward(trainer, pl_module)

loss.backward() 之后和优化器步进之前调用。

返回类型

None

on_before_backward(trainer, pl_module, loss)

loss.backward() 之前调用。

返回类型

None

on_before_optimizer_step(trainer, pl_module, optimizer)

optimizer.step() 之前调用。

返回类型

None

on_before_zero_grad(trainer, pl_module, optimizer)

optimizer.zero_grad() 之前调用。

返回类型

None

on_exception(trainer, pl_module, exception)

当任何训练执行因异常而中断时调用。

返回类型

None

on_fit_end(trainer, pl_module)

当拟合结束时调用。

返回类型

None

on_fit_start(trainer, pl_module)

当拟合开始时调用。

返回类型

None

on_load_checkpoint(trainer, pl_module, checkpoint)

在加载模型检查点时调用,用于重新加载状态。

参数
  • trainer (Trainer) – 当前的 Trainer 实例。

  • pl_module (LightningModule) – 当前的 LightningModule 实例。

  • checkpoint (Dict[str, Any]) – 由 Trainer 加载的完整检查点字典。

返回类型

None

on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

当预测批次结束时调用。

返回类型

None

on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

在预测批次开始时调用。

返回类型

None

on_predict_end(trainer, pl_module)

预测结束时调用。

返回类型

None

on_predict_epoch_end(trainer, pl_module)

在预测时期结束时调用。

返回类型

None

on_predict_epoch_start(trainer, pl_module)

在预测周期开始时调用。

返回类型

None

on_predict_start(trainer, pl_module)

当预测开始时调用。

返回类型

None

on_sanity_check_end(*_)

当验证健全性检查结束时调用。

返回类型

None

on_sanity_check_start(*_)

在验证健全性检查开始时调用。

返回类型

None

on_save_checkpoint(trainer, pl_module, checkpoint)

在保存检查点时调用,以便您有机会存储任何您可能想要保存的其他内容。

参数
  • trainer (Trainer) – 当前的 Trainer 实例。

  • pl_module (LightningModule) – 当前的 LightningModule 实例。

  • checkpoint (Dict[str, Any]) – 将被保存的检查点字典。

返回类型

None

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

在测试批次结束时调用。

返回类型

None

on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

在测试批次开始时调用。

返回类型

None

on_test_end(trainer, pl_module)

测试结束时调用。

返回类型

None

on_test_epoch_end(trainer, pl_module)

在测试时期结束时调用。

返回类型

None

on_test_epoch_start(trainer, pl_module)

在测试时期开始时调用。

返回类型

None

on_test_start(trainer, pl_module)

测试开始时调用。

返回类型

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

在训练批次结束时调用。

备注

这里的 outputs["loss"] 值将是相对于 accumulate_grad_batches 的损失归一化值,该损失从 training_step 返回。

返回类型

None

on_train_batch_start(trainer, pl_module, batch, batch_idx)

当训练批次开始时调用。

返回类型

None

on_train_end(*_)

当列车结束时调用。

返回类型

None

on_train_epoch_end(trainer, pl_module)

在训练轮次结束时调用。

要在每个epoch结束时访问所有批处理输出,您可以将步骤输出缓存为 pytorch_lightning.core.LightningModule 的属性,并在此钩子中访问它们:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
返回类型

None

on_train_epoch_start(trainer, *_)

在训练轮次开始时调用。

返回类型

None

on_train_start(*_)

当列车开始时调用。

返回类型

None

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

当验证批次结束时调用。

返回类型

None

on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)

在验证批次开始时调用。

返回类型

None

on_validation_end(trainer, pl_module)

当验证循环结束时调用。

返回类型

None

on_validation_epoch_end(trainer, pl_module)

在 val 时期结束时调用。

返回类型

None

on_validation_epoch_start(trainer, pl_module)

在 val 纪元开始时调用。

返回类型

None

on_validation_start(trainer, pl_module)

在验证循环开始时调用。

返回类型

None

property predict_description: str
返回类型

str

property predict_progress_bar: tqdm_asyncio
返回类型

tqdm_asyncio

print(*args, sep=' ', **kwargs)

你应该提供一种在不中断进度条的情况下进行打印的方法。

返回类型

None

property process_position: int
返回类型

int

property refresh_rate: int
返回类型

int

reset_dataloader_idx_tracker()
返回类型

None

property sanity_check_description: str
返回类型

str

setup(trainer, pl_module, stage)

在 fit、validate、test、predict 或 tune 开始时调用。

返回类型

None

state_dict()

在保存检查点时调用,实现以生成回调的 state_dict

返回类型

Dict[str, Any]

返回

包含回调状态的字典。

property state_key: str

回调状态的标识符。

用于通过 checkpoint["callbacks"][state_key] 从检查点字典中存储和检索回调的状态。如果 1) 回调有状态且 2) 希望保持该回调的多个实例的状态,则回调的实现需要提供一个唯一的状态键。

返回类型

str

teardown(trainer, pl_module, stage)

当fit、validate、test、predict或tune结束时调用。

返回类型

None

property test_description: str
返回类型

str

property test_progress_bar: tqdm_asyncio
返回类型

tqdm_asyncio

property total_predict_batches_current_dataloader: Union[int, float]

预测批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。

使用此项来设置进度条中的总迭代次数。如果预测数据加载器的大小是无限的,则可以返回 inf

返回类型

Union[int, float]

property total_test_batches_current_dataloader: Union[int, float]

当前数据加载器中,测试批次的总数,可能会随着每个epoch而变化。

使用此项来设置进度条中的总迭代次数。如果测试数据加载器的大小是无限的,则可以返回 inf

返回类型

Union[int, float]

property total_train_batches: Union[int, float]

训练批次的总数,可能会随着每个epoch而变化。

使用此项来设置进度条中的总迭代次数。如果训练数据加载器的大小是无限的,则可以返回 inf

返回类型

Union[int, float]

property total_val_batches: Union[int, float]

验证批次的总数,对于所有验证数据加载器,这可能会随着每个epoch而变化。

使用此项来设置进度条中的总迭代次数。如果预测数据加载器的大小是无限的,则可以返回 inf

返回类型

Union[int, float]

property total_val_batches_current_dataloader: Union[int, float]

验证批次的总数,对于当前的数据加载器,这可能会随着每个epoch而变化。

使用此项来设置进度条中的总迭代次数。如果验证数据加载器的大小是无限的,则可以返回 inf

返回类型

Union[int, float]

property train_description: str
返回类型

str

property train_progress_bar: tqdm_asyncio
返回类型

tqdm_asyncio

property trainer: Trainer
返回类型

Trainer

property val_progress_bar: tqdm_asyncio
返回类型

tqdm_asyncio

property validation_description: str
返回类型

str