! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on Colab
Callbacks
from __future__ import annotations
from fastai.data.all import *
from fastai.optimizer import *
from fastai.losses import BaseLoss
from nbdev.showdoc import *
_all_ = ['CancelStepException','CancelBackwardException','CancelFitException','CancelEpochException','CancelTrainException','CancelValidException','CancelBatchException']
Basic callbacks for Learner
Events
Callbacks can occur at any of these times: after_create before_fit before_epoch before_train before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit.
_events = L.split('after_create before_fit before_epoch before_train before_batch after_pred after_loss \
    before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
    after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
    after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")

_all_ = ['event']

show_doc(event, name='event', title_level=3)
class event [source]
event(*args, **kwargs)

All possible events as attributes to get tab-completion and typo-proofing
To ensure you are referring to an event that exists (i.e. the name of one of the times when callbacks are called) and to get tab completion of event names, use `event`:
test_eq(event.before_step, 'before_step')
Callback -
= "before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split() _inner_loop
_ex_docs = dict(
    CancelBatchException="Skip the rest of this batch and go to `after_batch`",
    CancelTrainException="Skip the rest of the training part of the epoch and go to `after_train`",
    CancelValidException="Skip the rest of the validation part of the epoch and go to `after_validate`",
    CancelEpochException="Skip the rest of this epoch and go to `after_epoch`",
    CancelStepException="Skip stepping the optimizer",
    CancelBackwardException="Skip the backward pass and go to `after_backward`",
    CancelFitException="Interrupts training and go to `after_fit`")
for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)
@funcs_kwargs(as_method=True)
class Callback(Stateful,GetAttr):
    "Basic class handling tweaks of the training loop by changing a `Learner` in various events"
    order,_default,learn,run,run_train,run_valid = 0,'learn',None,True,True,True
    _methods = _events

    def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
    def __repr__(self): return type(self).__name__

    def __call__(self, event_name):
        "Call `self.{event_name}` if it's defined"
        _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
                (self.run_valid and not getattr(self, 'training', False)))
        res = None
        if self.run and _run:
            try: res = getcallable(self, event_name)()
            except (CancelBatchException, CancelBackwardException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
            except Exception as e: raise modify_exception(e, f'Exception occurred in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}', replace=True)
        if event_name=='after_fit': self.run=True  # Reset self.run to True at each end of fit
        return res

    def __setattr__(self, name, value):
        "Set an attribute for a `Callback`"
        if hasattr(self.learn,name):
            warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
        super().__setattr__(name, value)

    @property
    def name(self):
        "Name of the `Callback`, camel-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')
The training loop is defined in `Learner` a bit below and consists of a minimal set of instructions: looping through the data, we

- compute the output of the model from the input
- calculate a loss between this output and the desired target
- compute the gradients of this loss with respect to all the model parameters
- update the parameters accordingly
- zero all the gradients

Any tweak of this training loop is defined in a `Callback` to avoid over-complicating the code of the training loop, and to make it easy to mix and match different techniques (since they'll be defined in different callbacks). A callback can implement actions on the following events (a small usage sketch follows the list):
- `after_create`: called after the `Learner` is created
- `before_fit`: called before starting training or inference, ideal for initial setup
- `before_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch
- `before_train`: called at the beginning of the training part of an epoch
- `before_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (with techniques like mixup for instance)
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance)
- `before_backward`: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)
- `after_backward`: called after the backward pass, but before the update of the parameters. Generally `before_step` should be used instead
- `before_step`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance)
- `after_step`: called after the step and before the gradients are zeroed
- `after_batch`: called at the end of a batch, for any clean-up before the next one
- `after_train`: called at the end of the training phase of an epoch
- `before_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation
- `after_validate`: called at the end of the validation part of an epoch
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one
- `after_fit`: called at the end of training, for final clean-up
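To get a feel for where these events fire, here is a minimal sketch (not part of the library) of a callback that logs a few of them:

class PrintEventsCallback(Callback):
    # Sketch only: print a message at a few of the events listed above.
    # Attach it with e.g. `learn.fit(1, cbs=PrintEventsCallback())`.
    def before_epoch(self): print(f'starting epoch {self.epoch}')
    def after_batch(self): print(f'finished iter {self.iter} of {self.n_iter}')
    def after_fit(self): print('training done')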
show_doc(Callback.__call__)
One way to define callbacks is through subclassing:
class _T(Callback):
    def call_me(self): return "maybe"

test_eq(_T()("call_me"), "maybe")
Another way is by passing the callback function to the constructor:
def cb(self): return "maybe"
_t = Callback(before_fit=cb)
test_eq(_t(event.before_fit), "maybe")
`Callback`s provide a shortcut to avoid having to write `self.learn.bla` for any `bla` attribute we need; just write `self.bla`. This only works for getting attributes, not for setting them.
mk_class('TstLearner', 'a')

class TstCallback(Callback):
    def batch_begin(self): print(self.a)

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
test_stdout(lambda: cb('batch_begin'), "1")
If you want to change the value of an attribute, you have to use `self.learn.bla`, not `self.bla`. In the example below, `self.a += 1` creates an `a` attribute of 2 in the callback instead of setting the `a` of the learner to 2. It also issues a warning that something is probably wrong:
learn.a
1
class TstCallback(Callback):
    def batch_begin(self): self.a += 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.a, 2)
test_eq(cb.learn.a, 1)
/tmp/ipykernel_5201/1369389649.py:29: UserWarning: You are shadowing an attribute (a) that exists in the learner. Use `self.learn.a` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
A proper version needs to write `self.learn.a = self.a + 1`:
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.learn.a, 2)
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = 1 + "a"

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
with ExceptionExpected(TypeError, regex=" in `TstCallback` when calling event `batch_begin`"):
    cb('batch_begin')
show_doc(Callback.name, name='Callback.name')
test_eq(TstCallback().name, 'tst')

class ComplicatedNameCallback(Callback): pass

test_eq(ComplicatedNameCallback().name, 'complicated_name')
TrainEvalCallback -
class TrainEvalCallback(Callback):
    "`Callback` that tracks the number of iterations done and properly sets training/eval mode"
    order,run_valid = -10,False
    def after_create(self): self.learn.n_epoch = 1

    def before_fit(self):
        "Set the iter and epoch counters to 0, put the model on the right device"
        self.learn.epoch,self.learn.loss = 0,tensor(0.)
        self.learn.train_iter,self.learn.pct_train = 0,0.
        device = getattr(self.dls, 'device', default_device())
        self.model.to(device)
        if isinstance(self.loss_func, (nn.Module, BaseLoss)): self.loss_func.to(device)
        if hasattr(self.model, 'reset'): self.model.reset()

    def after_batch(self):
        "Update the iter counter (in training mode)"
        self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
        self.learn.train_iter += 1

    def before_train(self):
        "Set the model to training mode"
        self.learn.pct_train=self.epoch/self.n_epoch
        self.model.train()
        self.learn.training=True

    def before_validate(self):
        "Set the model to validation mode"
        self.model.eval()
        self.learn.training=False
show_doc(TrainEvalCallback, title_level=3)
class TrainEvalCallback [source]
TrainEvalCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

`Callback` that tracks the number of iterations done and properly sets training/eval mode
This `Callback` is automatically added in every `Learner` at initialization.
# The code testing TrainEvalCallback is in Learner.fit
if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]
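Since `defaults.callbacks` seeds the callback list of every new `Learner`, a quick sanity check (a sketch, not from the original notebook) is:

# Sketch: TrainEvalCallback should be registered in the default callback list.
assert TrainEvalCallback in defaults.callbacks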
Attributes available to callbacks
When writing a callback, the following attributes of `Learner` are available (a usage sketch follows these lists):
- `model`: the model used for training/validation
- `dls`: the underlying `DataLoaders`
- `loss_func`: the loss function used
- `opt`: the optimizer used to update the model parameters
- `opt_func`: the function used to create the optimizer
- `cbs`: the list containing all the `Callback`s
- `dl`: the current `DataLoader` used for iteration
- `x`/`xb`: the last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.
- `y`/`yb`: the last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.
- `pred`: the last predictions from `self.model` (potentially modified by callbacks)
- `loss_grad`: the last computed loss (potentially modified by callbacks)
- `loss`: a clone of `loss_grad` used for logging
- `n_epoch`: the number of epochs in this training
- `n_iter`: the number of iterations in the current `self.dl`
- `epoch`: the current epoch index (from 0 to `n_epoch-1`)
- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)
The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:

- `train_iter`: the number of training iterations done since the beginning of this training
- `pct_train`: from 0. to 1., the percentage of training iterations completed
- `training`: flag to indicate whether or not we're in training mode
The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:

- `smooth_loss`: an exponentially-averaged version of the training loss
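As an illustration, here is a hypothetical callback (not part of the library) combining several of these attributes to periodically log the smoothed training loss; the `order` value is an assumption chosen so it runs after `Recorder` and `smooth_loss` is populated:

class LogLossCallback(Callback):
    # Sketch only: uses self.training, self.epoch, self.iter, self.pct_train
    # and self.smooth_loss (added by Recorder), all documented above.
    order = 60  # assumption: large enough to run after Recorder
    def __init__(self, freq=100): self.freq = freq
    def after_batch(self):
        if self.training and self.iter % self.freq == 0:
            print(f'epoch {self.epoch}, iter {self.iter}: '
                  f'smooth loss {self.smooth_loss:.4f} ({self.pct_train:.0%} of training done)')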
Callbacks control flow
It happens that we may want to skip some of the steps of the training loop: in gradient accumulation, for instance, we don't always want to do the step/zeroing of the grads; during an LR finder test, we don't want to do the validation phase of an epoch; and if we're training with an early-stopping strategy, we want to be able to completely interrupt the training loop.

This is made possible by raising specific exceptions the training loop will look for (and properly catch).
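For instance, an early-stopping behavior only needs to raise `CancelFitException` at the right time. A stripped-down sketch (fastai's full version is `EarlyStoppingCallback` in `callback.tracker`; the indexing of `recorder.values` below is an assumption for illustration):

class SimpleEarlyStopCallback(Callback):
    # Sketch only: stop training when the validation loss has not improved
    # for `patience` consecutive epochs.
    def __init__(self, patience=2): self.patience,self.best,self.wait = patience,float('inf'),0
    def after_epoch(self):
        val_loss = self.recorder.values[-1][1]  # assumption: column 1 is the valid loss
        if val_loss < self.best: self.best,self.wait = val_loss,0
        else:
            self.wait += 1
            if self.wait >= self.patience: raise CancelFitException()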
show_doc(CancelStepException, title_level=3)

class CancelStepException [source]
CancelStepException(*args, **kwargs) :: Exception

Skip stepping the optimizer
show_doc(CancelBatchException, title_level=3)

class CancelBatchException [source]
CancelBatchException(*args, **kwargs) :: Exception

Skip the rest of this batch and go to after_batch
show_doc(CancelBackwardException, title_level=3)

class CancelBackwardException [source]
CancelBackwardException(*args, **kwargs) :: Exception

Skip the backward pass and go to after_backward
show_doc(CancelTrainException, title_level=3)

class CancelTrainException [source]
CancelTrainException(*args, **kwargs) :: Exception

Skip the rest of the training part of the epoch and go to after_train
show_doc(CancelValidException, title_level=3)

class CancelValidException [source]
CancelValidException(*args, **kwargs) :: Exception

Skip the rest of the validation part of the epoch and go to after_validate
show_doc(CancelEpochException, title_level=3)

class CancelEpochException [source]
CancelEpochException(*args, **kwargs) :: Exception

Skip the rest of this epoch and go to after_epoch
show_doc(CancelFitException, title_level=3)

class CancelFitException [source]
CancelFitException(*args, **kwargs) :: Exception

Interrupts training and go to after_fit
You can detect that one of those exceptions occurred and add code that executes right after with the following events:

- `after_cancel_batch`: reached immediately after a `CancelBatchException` before proceeding to `after_batch`
- `after_cancel_train`: reached immediately after a `CancelTrainException` before proceeding to `after_train`
- `after_cancel_validate`: reached immediately after a `CancelValidException` before proceeding to `after_validate`
- `after_cancel_epoch`: reached immediately after a `CancelEpochException` before proceeding to `after_epoch`
- `after_cancel_fit`: reached immediately after a `CancelFitException` before proceeding to `after_fit`
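Putting this together, the gradient-accumulation case mentioned above can be sketched as follows (fastai ships a complete version as `GradientAccumulation` in `callback.training`; raising `CancelBatchException` from `before_step` skips both the optimizer step and the zeroing of the gradients, so they keep accumulating):

class AccumulateSketchCallback(Callback):
    # Sketch only: skip step/zero_grad until n_acc items have been seen.
    run_valid = False  # only alter the training phase
    def __init__(self, n_acc=32): self.n_acc = n_acc
    def before_fit(self): self.count = 0
    def before_step(self):
        self.count += find_bs(self.learn.yb)
        if self.count < self.n_acc: raise CancelBatchException()  # after_cancel_batch fires next
        self.count = 0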
Gather and fetch predictions callbacks -
class GatherPredsCallback(Callback):
    "`Callback` that returns all predictions and targets, optionally `with_input` or `with_loss`"
    _stateattrs=('preds','targets','inputs','losses')
    def __init__(self,
        with_input:bool=False, # Whether to return inputs
        with_loss:bool=False, # Whether to return losses
        save_preds:Path=None, # Path to save predictions
        save_targs:Path=None, # Path to save targets
        with_preds:bool=True, # Whether to return predictions
        with_targs:bool=True, # Whether to return targets
        concat_dim:int=0, # Dimension to concatenate returned tensors
        pickle_protocol:int=2 # Pickle protocol used to save predictions and targets
    ):
        store_attr()

    def before_batch(self):
        "If `with_input`, detach batch inputs"
        if self.with_input: self.inputs.append((self.learn.to_detach(self.xb)))

    def before_validate(self):
        "Initialize containers"
        self.preds,self.targets = [],[]
        if self.with_input: self.inputs = []
        if self.with_loss:  self.losses = []

    def after_batch(self):
        "Save predictions, targets and potentially losses"
        if not hasattr(self, 'pred'): return
        preds,targs = self.learn.to_detach(self.pred),self.learn.to_detach(self.yb)
        if self.with_preds: self.preds.append(preds)
        if self.with_targs: self.targets.append(targs)
        if self.save_preds is not None:
            torch.save(preds, self.save_preds/str(self.iter), pickle_protocol=self.pickle_protocol)
        if self.save_targs is not None:
            torch.save(targs[0], self.save_targs/str(self.iter), pickle_protocol=self.pickle_protocol)
        if self.with_loss:
            bs = find_bs(self.yb)
            loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
            self.losses.append(self.learn.to_detach(loss))

    def after_validate(self):
        "Concatenate all recorded tensors"
        if not hasattr(self, 'preds'): return
        if self.with_input: self.inputs  = detuplify(to_concat(self.inputs, dim=self.concat_dim))
        if self.with_preds: self.preds   = detuplify(to_concat(self.preds, dim=self.concat_dim))
        if self.with_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
        if self.with_loss:  self.losses  = to_concat(self.losses)

    def all_tensors(self) -> (Tensor, list):
        "Returns all recorded tensors in the order [inputs, preds, targets, losses]"
        res = [self.preds if self.with_preds else None, self.targets if self.with_targs else None]
        if self.with_input: res = [self.inputs] + res
        if self.with_loss:  res.append(self.losses)
        return res
show_doc(GatherPredsCallback, title_level=3)
class GatherPredsCallback [source]
GatherPredsCallback(with_input:bool=False, with_loss:bool=False, save_preds:(str, PathLike)=None, save_targs:(str, PathLike)=None, with_preds:bool=True, with_targs:bool=True, concat_dim:int=0, pickle_protocol:int=2) :: Callback

`Callback` that returns all predictions and targets, optionally `with_input` or `with_loss`
| | Type | Default | Details |
|---|---|---|---|
| with_input | bool | False | Whether to return inputs |
| with_loss | bool | False | Whether to return losses |
| save_preds | (str, PathLike) | None | Path to save predictions |
| save_targs | (str, PathLike) | None | Path to save targets |
| with_preds | bool | True | Whether to return predictions |
| with_targs | bool | True | Whether to return targets |
| concat_dim | int | 0 | Dimension to concatenate returned tensors |
| pickle_protocol | int | 2 | Pickle protocol used to save predictions and targets |
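`Learner.get_preds` creates this callback internally, but it can also be attached by hand. A minimal usage sketch, assuming an already-trained `learn` (not defined here):

gp = GatherPredsCallback(with_loss=True)
with learn.added_cbs(gp):      # temporarily attach the callback
    learn.validate()           # run the validation loop; gp records as it goes
preds,targs,losses = gp.preds,gp.targets,gp.losses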
class FetchPredsCallback(Callback):
    "A callback to fetch predictions during the training loop"
    remove_on_fetch = True
    def __init__(self,
        ds_idx:int=1, # Index of dataset, 0 for train, 1 for valid, used if `dl` is not present
        dl:DataLoader=None, # `DataLoader` used for fetching `Learner` predictions
        with_input:bool=False, # Whether to return inputs in `GatherPredsCallback`
        with_decoded:bool=False, # Whether to return decoded predictions
        cbs:Callback|MutableSequence=None, # `Callback` to temporarily remove from `Learner`
        reorder:bool=True # Whether to sort prediction results
    ):
        self.cbs = L(cbs)
        store_attr('ds_idx,dl,with_input,with_decoded,reorder')

    def after_validate(self):
        "Fetch predictions from `Learner` without `self.cbs` and `remove_on_fetch` callbacks"
        to_rm = L(cb for cb in self.learn.cbs if getattr(cb, 'remove_on_fetch', False))
        with self.learn.removed_cbs(to_rm + self.cbs) as learn:
            self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
                with_input=self.with_input, with_decoded=self.with_decoded, inner=True, reorder=self.reorder)
show_doc(FetchPredsCallback, title_level=3)
class FetchPredsCallback [source]
FetchPredsCallback(ds_idx:int=1, dl:DataLoader=None, with_input:bool=False, with_decoded:bool=False, cbs:list=None, reorder:bool=True) :: Callback

A callback to fetch predictions during the training loop
| | Type | Default | Details |
|---|---|---|---|
| ds_idx | int | 1 | Index of dataset, 0 for train, 1 for valid, used if `dl` is not present |
| dl | DataLoader | None | `DataLoader` used for fetching `Learner` predictions |
| with_input | bool | False | Whether to return inputs in `GatherPredsCallback` |
| with_decoded | bool | False | Whether to return predicted classes |
| cbs | list | None | `Callback` list to add to the `Learner` |
| reorder | bool | True | Whether to sort prediction results |
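A usage sketch, again assuming an existing `learn`: attach the callback for a fit, and after each validation phase its `preds` attribute holds the result of `learn.get_preds`:

fp = FetchPredsCallback(ds_idx=1, with_decoded=True)
learn.fit(2, cbs=fp)
preds,targs,decoded = fp.preds  # assumption: tuple shape returned by get_preds(with_decoded=True)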
Export -
from nbdev import nbdev_export
nbdev_export()