训练回调

! [ -e /content ] && pip install -Uqq fastai  # 在Colab上升级fastai
# 默认导出:callback.training
from __future__ import annotations
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import *
from nbdev.showdoc import *
from fastai.test_utils import *
from fastai.vision.all import *

各种回调用于自定义训练行为

短期周期回调 -

class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid
    def after_batch(self):
        if self.iter/self.n_iter < self.pct: return
        if self.training:    raise CancelTrainException
        if self.short_valid: raise CancelValidException
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
epoch train_loss valid_loss time
0 00:00
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
epoch train_loss valid_loss time
0 8.432135 00:00

梯度累积 -

class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    order,run_valid = MixedPrecision.order-4,False
    def __init__(self, n_acc=32): store_attr()
    def before_fit(self): self.count=0
    def after_loss(self): self.learn.loss_grad /= self.n_acc/find_bs(self.learn.yb)
    def before_step(self):
        "Skip weight update if we have not seen enough items"
        self.learn.loss_grad *= self.n_acc/find_bs(self.learn.yb) # 日志正确损失
        self.count += find_bs(self.learn.yb)
        if self.count<self.n_acc: raise CancelBatchException() # 跳过步骤/清零梯度
        else: self.count=0
class GetGrads(Callback):
    run_valid,order = False,GradientAccumulation.order+1
    def before_step(self): self.grads=to_detach(L([p.grad.clone() for p in self.model.parameters()]))

def _test_acc(bs,n,cbs=None,cuda=False):
    with no_random(99): 
        db=synth_dbunch(bs=bs,n_train=n,n_valid=n,cuda=cuda)
        learn = synth_learner(data=db,cbs=[GetGrads]+L(cbs))
        learn.fit(1, lr=0.01)
        train,valid = learn.recorder.values[-1]
        return train,valid,learn.get_grads.grads

acc_cb = GradientAccumulation(n_acc=8)

train1,valid1,grads1 = _test_acc(8,1)
train2,valid2,grads2 = _test_acc(1,8,acc_cb)

#梯度应相同,有效损失相同,训练损失不同
test_close(grads2,grads1)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch train_loss valid_loss time
0 0.834062 0.295950 00:00
epoch train_loss valid_loss time
0 0.824550 0.295950 00:00
#|cuda
fp16_cb = MixedPrecision(init_scale=1024)
train1,valid1,grads1 = _test_acc(8,1, fp16_cb, cuda=True)
train2,valid2,grads2 = _test_acc(1,8, [acc_cb,fp16_cb], cuda=True)
test_close(grads2,grads1, eps=0.01)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch train_loss valid_loss time
0 0.834062 0.295950 00:00
epoch train_loss valid_loss time
0 0.824550 0.295950 00:00

当每次累积的步数大于批次数时,参数(因此验证损失)完全不变:

learn = synth_learner()
learn.fit(1, lr=0.01, cbs=GradientAccumulation(n_acc=1000))
# 确保valid_loss没有变化
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
epoch train_loss valid_loss time
0 20.987558 26.849480 00:00

梯度剪切 -

class GradientClip(Callback):
    "Clip norm of gradients"
    order=MixedPrecision.order+1
    def __init__(self,max_norm:float=1., norm_type:float=2.0): store_attr()
    def before_step(self): nn.utils.clip_grad_norm_(self.parameters(), self.max_norm, self.norm_type)

通常情况下,如果我们使用的学习率过高,训练将会发散。这甚至发生在我们使用混合精度训练时,尽管通过动态损失缩放来避免无穷大,但仍然会发散:

fp16 = MixedPrecision()
set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=fp16)
epoch train_loss valid_loss time
0 38.214138 25.269005 00:00
1 377.145508 890.010376 00:00
2 839.392883 9965.747070 00:00

通过添加 GradientClip 回调,梯度的 norm_type(默认值:2)会被限制为最多 max_norm(默认值:1),这可以避免损失发散:

set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=[GradientClip,fp16])
epoch train_loss valid_loss time
0 2.039428 2.372177 00:00
1 1.402425 0.300728 00:00
2 1.013548 0.332610 00:00

BnFreeze

bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval=True)->None:
    "Set bn layers in eval mode for all recursive children of `m`."
    for l in m.children():
        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:
            if use_eval: l.eval()
            else:        l.train()
        set_bn_eval(l)

class BnFreeze(Callback):
    run_after=TrainEvalCallback
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    def before_train(self):
        set_bn_eval(self.model)

BnFreeze 在您想要训练两个具有共同特征提取器/主体的独立模型时非常有用。模型唯一不同的部分是您为迁移学习附加的头部。

Learner.freeze() 在这里不够,因为 BatchNorm 层默认是可训练的,并且正在跟踪批次的运行均值和标准差。为了使特征提取器完全匹配,您需要设置 train_bn=False,并且这些统计数据也需要被冻结,这正是 BnFreeze 的功能。

::: {#cell-26 .cell 0=‘缓’ 1=‘慢’}

path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)

:::

# 内存格式教程

在此教程中,我们将探索 PyTorch 中的内存格式,并了解何时使用不同的内存格式是合适的。我们将讨论使用 `channels_last``channels_first` 格式的优缺点,以及如何在模型和输入数据之间进行转换。

## 为什么选择不同的内存格式?

在深度学习中,输入数据通常是图像,每个图像都有三个维度:高度、宽度和通道。PyTorch 中可以选择两种内存格式:

- `channels_first`: 图像的形状为 `(batch_size, channels, height, width)`
- `channels_last`: 图像的形状为 `(batch_size, height, width, channels)`

对于某些操作,`channels_last` 格式在兼容性和性能上可能优于 `channels_first` 格式。特别是在处理大型图像数据时,调整内存格式可以显著提高性能。

## 创建一个示例

我们首先通过创建一个 `Learner` 来展示当仅使用 `train_bn=False` 时运行统计的不匹配情况...

::: {#cell-28 .cell 0=‘缓’ 1=‘慢’}

learn1 = vision_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)

:::

…并获取第一个BatchNorm层,并存储它的运行均值:

::: {#cell-30 .cell 0=‘缓’ 1=‘慢’}

m = learn1.model[0][1].running_mean.clone()

:::

您可以看到现在运行均值已发生变化:

::: {#cell-32 .cell 0=‘缓’ 1=‘慢’}

learn1.fit(1, lr=0.02)
test_ne(to_detach(learn1.model[0][1].running_mean), m)
epoch train_loss valid_loss time
0 1.148303 0.739404 00:12

:::

使用 BnFreeze 回调时,运行统计数据在训练过程中不会改变。这通常对于从迁移学习中获得良好结果非常重要。

::: {#cell-34 .cell 0=‘缓’ 1=‘慢’}

learn1 = vision_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.detach().clone()
learn1.fit(1, lr=0.02)
test_eq(to_detach(learn1.model[0][1].running_mean), m)
epoch train_loss valid_loss time
0 0.478594 0.270772 00:10

:::

导出 -

from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 36_text.models.qrnn.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.