from __future__ import annotations
from fastai.data.all import *
from fastai.optimizer import *
from fastai.losses import BaseLoss
from nbdev.showdoc import *
# nbdev export list for this cell. These exception classes are created dynamically by
# `mk_class` further down, so they are listed here explicitly; presumably nbdev adds the
# names in `_all_` to the exported module's `__all__` (TODO confirm against nbdev docs).
_all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']



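回调可以在以下任意时刻触发:创建后(after_create)、拟合前(before_fit)、每个纪元前(before_epoch)、训练前(before_train)、批处理前(before_batch)、预测后(after_pred)、损失计算后(after_loss)、反向传播前(before_backward)、取消反向传播后(after_cancel_backward)、反向传播后(after_backward)、优化器步进前(before_step)、取消步进后(after_cancel_step)、步进后(after_step)、取消批处理后(after_cancel_batch)、批处理后(after_batch)、取消训练后(after_cancel_train)、训练后(after_train)、验证前(before_validate)、取消验证后(after_cancel_validate)、验证后(after_validate)、取消纪元后(after_cancel_epoch)、纪元后(after_epoch)、取消拟合后(after_cancel_fit)、拟合后(after_fit)。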

# Every training-loop event a `Callback` can respond to, in the order the loop reaches them.
_events = L.split(
    'after_create before_fit before_epoch before_train before_batch after_pred after_loss '
    'before_backward after_cancel_backward after_backward before_step after_cancel_step after_step '
    'after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate '
    'after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit'
)

# Build the `event` namespace class: one attribute per event, whose value is the event's own
# name, so `event.before_fit == 'before_fit'` and misspellings fail instead of passing silently.
event_attrs = {ev: ev for ev in _events}
mk_class('event', **event_attrs,
         doc="All possible events as attributes to get tab-completion and typo-proofing")
del event_attrs
_all_ = ['event']
show_doc(event, name='event', title_level=3)

class event[source]

event(*args, **kwargs)

All possible events as attributes to get tab-completion and typo-proofing


# Each `event` attribute is simply its own name as a string.
test_eq(event.before_step, 'before_step')

回调 -

# Per-batch events; `Callback.__call__` gates these on `run_train` / `run_valid`.
_inner_loop = ('before_batch after_pred after_loss before_backward after_cancel_backward '
               'after_backward before_step after_step after_cancel_batch after_batch').split()

# Docstrings for the control-flow exceptions a callback may raise to alter the training loop.
_ex_docs = {
    'CancelBatchException':    "Skip the rest of this batch and go to `after_batch`",
    'CancelTrainException':    "Skip the rest of the training part of the epoch and go to `after_train`",
    'CancelValidException':    "Skip the rest of the validation part of the epoch and go to `after_validate`",
    'CancelEpochException':    "Skip the rest of this epoch and go to `after_epoch`",
    'CancelStepException':     "Skip stepping the optimizer",
    'CancelBackwardException': "Skip the backward pass and go to `after_backward`",
    'CancelFitException':      "Interrupts training and go to `after_fit`",
}

# Create each exception as a plain `Exception` subclass carrying its docstring.
for c, d in _ex_docs.items():
    mk_class(c, sup=Exception, doc=d)
# `funcs_kwargs` lets event handlers be passed as keyword arguments (e.g.
# `Callback(before_fit=f)`), binding each as a method; only names in `_methods` are accepted.
@funcs_kwargs(as_method=True)
class Callback(Stateful,GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    # order: presumably the sort key the Learner uses to order callbacks (lower runs first) - TODO confirm
    # _default='learn': unknown attribute reads are forwarded to `self.learn` by `GetAttr`
    # run: when False, `__call__` skips every event; reset to True after `after_fit`
    # run_train/run_valid: whether inner-loop events fire in training/validation mode
    order,_default,learn,run,run_train,run_valid = 0,'learn',None,True,True,True
    _methods = _events

    # Any kwarg left after `funcs_kwargs` consumed the known event names is an unknown event.
    def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        # Events outside the per-batch inner loop always run; inner-loop events run only if the
        # callback is enabled for the current mode (`training` is read from the learner).
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
               (self.run_valid and not getattr(self, 'training', False)))
        res = None
        if self.run and _run:
            try: res = getcallable(self, event_name)()
            # Control-flow exceptions must reach the training loop unchanged.
            except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
            except Exception as e:
                # Exceptions raised without arguments have empty `args`; don't mask them with an IndexError.
                detail = e.args[0] if e.args else ''
                raise modify_exception(e, f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{detail}', replace=True)
        if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
        return res

    def __setattr__(self, name, value):
        "Set an attribute for a `Callback`"
        # Writing `self.x` when the learner already has `x` creates a shadowing copy on the
        # callback instead of updating the learner, so warn about it.
        if hasattr(self.learn,name):
            warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`, snake-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')


训练循环遍历数据,对每个批次依次执行以下最小步骤:
  • 从输入计算模型的输出
  • 计算该输出与期望目标之间的损失
  • 计算此损失相对于所有模型参数的梯度
  • 相应地更新参数
  • 清零所有梯度


  • after_create:在创建Learner后调用
  • before_fit:在开始训练或推断之前调用,适合进行初始设置。
  • before_epoch:在每个epoch开始时调用,适用于需要在每个epoch重置的任何行为。
  • before_train:在epoch的训练部分开始时调用。
  • before_batch:在每个batch开始时调用,刚好在抽取该batch之后。可用于进行batch必要的准备(如超参数调度)或在输入/目标进入模型之前改变输入/目标(例如,使用mixup等技术进行输入的变更)。
  • after_pred:在计算batch的模型输出之后调用。可以用于在将输出传递到损失之前改变该输出。
  • after_loss:在计算损失之后调用,但在反向传播之前调用。可以用于在损失中添加任何惩罚(例如,RNN训练中的AR或TAR)。
  • before_backward:在计算损失后调用,但仅在训练模式下(即,当将使用反向传播时)。
  • after_backward:在反向传播之后调用,但在参数更新之前。通常应该使用before_step替代。
  • before_step:在反向传播之后调用,但在参数更新之前。可用于在该更新之前对梯度进行任何更改(例如梯度裁剪)。
  • after_step:在步骤之后调用,并在梯度被清零之前调用。
  • after_batch:在batch结束时调用,用于在下一个batch之前进行任何清理。
  • after_train:在一个epoch的训练阶段结束时调用。
  • before_validate:在一个epoch的验证阶段开始时调用,适用于任何特定于验证的设置。
  • after_validate:在一个epoch的验证部分结束时调用。
  • after_epoch:在一个epoch结束时调用,用于在下一个epoch之前进行任何清理。
  • after_fit:在训练结束时调用,用于最终的清理。



Call self.{event_name} if it’s defined


class _T(Callback):
    # A custom method can be triggered by name just like a built-in event.
    def call_me(self):
        return "maybe"

test_eq(_T()("call_me"), "maybe")


# Event handlers may also be injected through the constructor as keyword arguments.
def cb(self):
    return "maybe"

_t = Callback(before_fit=cb)
test_eq(_t(event.before_fit), "maybe")


mk_class('TstLearner', 'a')

class TstCallback(Callback):
    def batch_begin(self):
        # `a` is not set on the callback, so GetAttr forwards the read to `self.learn`.
        print(self.a)

learn = TstLearner(1)
cb = TstCallback()
cb.learn = learn
test_stdout(lambda: cb('batch_begin'), "1")

如果你想改变一个属性的值,你必须使用 self.learn.bla,而不是 self.bla。在下面的示例中,self.a += 1 在回调中创建了一个值为 2 的 a 属性,而不是将学习者的 a 设置为 2。它还发出了一个警告,表示可能存在问题:

class TstCallback(Callback):
    # Deliberate anti-pattern: `self.a += 1` reads `a` from the learner (via GetAttr) but writes
    # it on the callback, shadowing the learner's attribute and triggering a warning.
    def batch_begin(self): self.a += 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
# Fire the event; without this call `cb.a` would still resolve to the learner's value of 1.
cb('batch_begin')
test_eq(cb.a, 2)
test_eq(cb.learn.a, 1)
/tmp/ipykernel_5201/1369389649.py:29: UserWarning: You are shadowing an attribute (a) that exists in the learner. Use `self.learn.a` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")

一个正确的版本需要写 self.learn.a = self.a + 1:

class TstCallback(Callback):
    # Correct pattern: write through `self.learn` so the learner's attribute is updated.
    def batch_begin(self): self.learn.a = self.a + 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
# Fire the event so the learner's `a` is actually incremented before the check.
cb('batch_begin')
test_eq(cb.learn.a, 2)
class TstCallback(Callback):
    # Deliberately fails with a TypeError to check the error message `Callback.__call__` builds.
    def batch_begin(self): self.learn.a = 1 + "a"
learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
with ExceptionExpected(TypeError, regex=" in `TstCallback` when calling event `batch_begin`"):
    cb('batch_begin')
show_doc(Callback.name, name='Callback.name')


Name of the Callback, snake-cased and with ‘Callback’ removed

# `.name` is read without being called, so these checks require `Callback.name` to be a property.
# Expected values: the `Callback` suffix is dropped and CamelCase becomes snake_case.
test_eq(TstCallback().name, 'tst')
class ComplicatedNameCallback(Callback): pass
test_eq(ComplicatedNameCallback().name, 'complicated_name')

TrainEvalCallback -

class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    # run_valid=False: per-batch events (e.g. `after_batch`) only fire while training.
    order,run_valid = -10,False
    def after_create(self): self.learn.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        self.learn.epoch,self.learn.loss = 0,tensor(0.)
        self.learn.train_iter,self.learn.pct_train = 0,0.
        device = getattr(self.dls, 'device', default_device())
        self.model.to(device)
        if isinstance(self.loss_func, (nn.Module, BaseLoss)): self.loss_func.to(device)
        if hasattr(self.model, 'reset'): self.model.reset()

    def after_batch(self):
        "Update the iter counter (in training mode)"
        self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
        self.learn.train_iter += 1

    def before_train(self):
        "Set the model to training mode"
        # Re-anchor progress at the epoch boundary so skipped batches (e.g. a cancelled
        # training phase) don't make `pct_train` drift.
        self.learn.pct_train=self.epoch/self.n_epoch
        self.model.train()
        # Other callbacks read this flag (see `Callback.__call__`).
        self.learn.training=True

    def before_validate(self):
        "Set the model to validation mode"
        self.model.eval()
        self.learn.training=False
# Render the nbdev API docs (signature and docstring) for `TrainEvalCallback`.
show_doc(TrainEvalCallback, title_level=3)

class TrainEvalCallback[source]

TrainEvalCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Callback that tracks the number of iterations done and properly sets training/eval mode

这个 Callback 在每个 Learner 初始化时会自动添加。

# Register `TrainEvalCallback` as a default callback unless defaults were already configured
# (tests exercising it through `Learner.fit` come later).
if not hasattr(defaults, 'callbacks'):
    defaults.callbacks = [TrainEvalCallback]


在编写回调时,Learner 的以下属性是可用的:

  • model:用于训练/验证的模型
  • dls:基础 DataLoaders
  • loss_func:使用的损失函数
  • opt:用于更新模型参数的优化器
  • opt_func:用于创建优化器的函数
  • cbs:包含所有 Callback 的列表
  • dl:当前用于迭代的 DataLoader
  • x/xb:最后从 self.dl 中抽取的输入(可能被回调修改)。xb 始终是一个元组(可能只有一个元素),而 x 是去元组化的。您只能给 xb 赋值。
  • y/yb:最后从 self.dl 中抽取的目标(可能被回调修改)。yb 始终是一个元组(可能只有一个元素),而 y 是去元组化的。您只能给 yb 赋值。
  • pred:最后从 self.model 得到的预测(可能被回调修改)
  • loss_grad:最后计算出的损失(可能被回调修改)
  • loss:用于记录的 loss_grad 的克隆
  • n_epoch:本次训练中的轮数
  • n_iter:当前 self.dl 中的迭代数量
  • epoch:当前轮次索引(从 0 到 n_epoch-1)
  • iter:当前在 self.dl 中的迭代索引(从 0 到 n_iter-1)

以下属性由 TrainEvalCallback 添加,除非您特别删除了该回调,否则应该是可用的:

  • train_iter:自本次训练开始以来完成的训练迭代次数
  • pct_train:从 0 到 1,已完成的训练迭代的百分比
  • training:指示我们是否处于训练模式的标志

以下属性由 Recorder 添加,除非您特别删除了该回调,否则应该是可用的:

  • smooth_loss:训练损失的指数加权平均值




# Render the nbdev API docs for `CancelStepException`.
show_doc(CancelStepException, title_level=3)

class CancelStepException[source]

CancelStepException(*args, **kwargs) :: Exception

Skip stepping the optimizer

# Render the nbdev API docs for `CancelBatchException`.
show_doc(CancelBatchException, title_level=3)

class CancelBatchException[source]

CancelBatchException(*args, **kwargs) :: Exception

Skip the rest of this batch and go to after_batch

# Render the nbdev API docs for `CancelBackwardException`.
show_doc(CancelBackwardException, title_level=3)

class CancelBackwardException[source]

CancelBackwardException(*args, **kwargs) :: Exception

Skip the backward pass and go to after_backward

# Render the nbdev API docs for `CancelTrainException`.
show_doc(CancelTrainException, title_level=3)

class CancelTrainException[source]

CancelTrainException(*args, **kwargs) :: Exception

Skip the rest of the training part of the epoch and go to after_train

# Render the nbdev API docs for `CancelValidException`.
show_doc(CancelValidException, title_level=3)

class CancelValidException[source]

CancelValidException(*args, **kwargs) :: Exception

Skip the rest of the validation part of the epoch and go to after_validate

# Render the nbdev API docs for `CancelEpochException`.
show_doc(CancelEpochException, title_level=3)

class CancelEpochException[source]

CancelEpochException(*args, **kwargs) :: Exception

Skip the rest of this epoch and go to after_epoch

# Render the nbdev API docs for `CancelFitException`.
show_doc(CancelFitException, title_level=3)

class CancelFitException[source]

CancelFitException(*args, **kwargs) :: Exception

Interrupts training and go to after_fit


  • after_cancel_batch:在 CancelBatchException 之后立即触发,然后继续执行 after_batch
  • after_cancel_train:在 CancelTrainException 之后立即触发,然后继续执行 after_train
  • after_cancel_validate:在 CancelValidException 之后立即触发,然后继续执行 after_validate
  • after_cancel_epoch:在 CancelEpochException 之后立即触发,然后继续执行 after_epoch
  • after_cancel_fit:在 CancelFitException 之后立即触发,然后继续执行 after_fit

收集和获取预测回调 -

class GatherPredsCallback(Callback):
    "`Callback` that returns all predictions and targets, optionally `with_input` or `with_loss`"
    def __init__(self,
        with_input:bool=False, # Whether to return inputs
        with_loss:bool=False, # Whether to return losses
        save_preds:Path=None, # Path (str or PathLike) of a directory to save predictions in
        save_targs:Path=None, # Path (str or PathLike) of a directory to save targets in
        with_preds:bool=True, # Whether to return predictions
        with_targs:bool=True, # Whether to return targets
        concat_dim:int=0, # Dimension to concatenate returned tensors
        pickle_protocol:int=2 # Pickle protocol used to save predictions and targets
    ):
        # Store every constructor argument as an attribute of the same name.
        store_attr()

    def before_batch(self):
        "If `with_input`, detach batch inputs"
        # `self.inputs` is created in `before_validate`.
        if self.with_input: self.inputs.append((self.learn.to_detach(self.xb)))

    def before_validate(self):
        "Initialize containers"
        self.preds,self.targets = [],[]
        if self.with_input: self.inputs = []
        if self.with_loss:  self.losses = []

    def after_batch(self):
        "Save predictions, targets and potentially losses"
        # Nothing to record if the learner has no `pred` yet.
        if not hasattr(self, 'pred'): return
        preds,targs = self.learn.to_detach(self.pred),self.learn.to_detach(self.yb)
        if self.with_preds: self.preds.append(preds)
        if self.with_targs: self.targets.append(targs)
        # Optionally write this batch to disk, one file per iteration index. `Path(...)` lets
        # callers pass either a str or a path object.
        if self.save_preds is not None:
            torch.save(preds, Path(self.save_preds)/str(self.iter), pickle_protocol=self.pickle_protocol)
        if self.save_targs is not None:
            # NOTE: only the first target tensor of `yb` is saved.
            torch.save(targs[0], Path(self.save_targs)/str(self.iter), pickle_protocol=self.pickle_protocol)
        if self.with_loss:
            bs = find_bs(self.yb)
            # Reduce to one value per sample when the loss is not already per-sample.
            loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
            self.losses.append(self.learn.to_detach(loss))

    def after_validate(self):
        "Concatenate all recorded tensors"
        if not hasattr(self, 'preds'): return
        if self.with_input: self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
        if self.with_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
        if self.with_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
        if self.with_loss:  self.losses  = to_concat(self.losses)

    def all_tensors(self) -> list:
        "Returns all recorded tensors in the order [inputs, preds, targets, losses]"
        # Disabled preds/targets are kept as `None` placeholders so positions stay stable.
        res = [self.preds if self.with_preds else None, self.targets if self.with_targs else None]
        if self.with_input: res = [self.inputs] + res
        if self.with_loss:  res.append(self.losses)
        return res
# Render the nbdev API docs (signature and parameter table) for `GatherPredsCallback`.
show_doc(GatherPredsCallback, title_level=3)

class GatherPredsCallback[source]

GatherPredsCallback(with_input:bool=False, with_loss:bool=False, save_preds:(str, PathLike)=None, save_targs:(str, PathLike)=None, with_preds:bool=True, with_targs:bool=True, concat_dim:int=0, pickle_protocol:int=2) :: Callback

Callback that returns all predictions and targets, optionally with_input or with_loss

Type Default Details
with_input bool False Whether to return inputs
with_loss bool False Whether to return losses
save_preds (str, PathLike) None Path to save predictions
save_targs (str, PathLike) None Path to save targets
with_preds bool True Whether to return predictions
with_targs bool True Whether to return targets
concat_dim int 0 Dimension to concatenate returned tensors
pickle_protocol int 2 Pickle protocol used to save predictions and targets
class FetchPredsCallback(Callback):
    "A callback to fetch predictions during the training loop"
    # Callbacks carrying this flag are removed from the learner while predictions are fetched.
    remove_on_fetch = True
    def __init__(self,
        ds_idx:int=1, # Index of dataset, 0 for train, 1 for valid, used if `dl` is not present
        dl:DataLoader=None, # `DataLoader` used for fetching `Learner` predictions
        with_input:bool=False, # Whether to return inputs in `GatherPredsCallback`
        with_decoded:bool=False, # Whether to return decoded predictions
        cbs:Callback|MutableSequence=None, # `Callback`s to temporarily remove from the `Learner`
        reorder:bool=True # Whether to sort prediction results
    ):
        self.cbs = L(cbs)
        # Store the remaining arguments as attributes read by `after_validate`.
        store_attr('ds_idx,dl,with_input,with_decoded,reorder')

    def after_validate(self):
        "Fetch predictions from `Learner` without `self.cbs` and `remove_on_fetch` callbacks"
        to_rm = L(cb for cb in self.learn.cbs if getattr(cb, 'remove_on_fetch', False))
        with self.learn.removed_cbs(to_rm + self.cbs) as learn:
            self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
                with_input=self.with_input, with_decoded=self.with_decoded, inner=True, reorder=self.reorder)
# Render the nbdev API docs (signature and parameter table) for `FetchPredsCallback`.
show_doc(FetchPredsCallback, title_level=3)

class FetchPredsCallback[source]

FetchPredsCallback(ds_idx:int=1, dl:DataLoader=None, with_input:bool=False, with_decoded:bool=False, cbs:Callback|MutableSequence=None, reorder:bool=True) :: Callback

A callback to fetch predictions during the training loop

Type Default Details
ds_idx int 1 Index of dataset, 0 for train, 1 for valid, used if dl is not present
dl DataLoader None DataLoader used for fetching Learner predictions
with_input bool False Whether to return inputs in GatherPredsCallback
with_decoded bool False Whether to return predicted classes
cbs Callback|MutableSequence None Callbacks to temporarily remove from the Learner while fetching predictions
reorder bool True Whether to sort prediction results

导出 -

from nbdev import nbdev_export
