# 默认导出:callback.training
from __future__ import annotations
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import *
from nbdev.showdoc import *
from fastai.test_utils import *
from fastai.vision.all import *
短 epoch 回调
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"

    def __init__(self, pct=0.01, short_valid=True):
        self.pct = pct                  # fraction of each epoch's batches to run before stopping
        self.short_valid = short_valid  # whether the validation phase is also cut short

    def after_batch(self):
        # Still inside the allowed fraction of the epoch: let it continue.
        if self.iter / self.n_iter < self.pct:
            return
        # Budget used up: cancel training, and validation only when requested.
        if self.training:
            raise CancelTrainException
        if self.short_valid:
            raise CancelValidException
# Both the training and validation phases are cut short after `pct` of their batches.
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
epoch | train_loss | valid_loss | time |
0 | 00:00 |
# With `short_valid=False` only training is cut short; validation runs in full.
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
epoch | train_loss | valid_loss | time |
0 | 8.432135 | 00:00 |
梯度累积
class GradientAccumulation(Callback):
    "Accumulate gradients before updating weights"
    # Ordered 4 slots before `MixedPrecision`; `run_valid=False` keeps it out of validation.
    order,run_valid = MixedPrecision.order-4,False
    def __init__(self, n_acc=32):
        "`n_acc`: number of training items to accumulate gradients over before each optimizer step."
        store_attr()
    def before_fit(self): self.count=0  # items seen since the last optimizer step
    def after_loss(self):
        "Scale the loss used for backward by `bs/n_acc` so gradients summed over `n_acc` items match one full step."
        self.learn.loss_grad /= self.n_acc/find_bs(self.learn.yb)
    def before_step(self):
        "Skip weight update if we have not seen enough items"
        bs = find_bs(self.learn.yb)
        # Undo the `after_loss` scaling so the logged loss is correct
        self.learn.loss_grad *= self.n_acc/bs
        self.count += bs
        # Not enough items yet: cancel the rest of the batch (skips the step and zero_grad)
        if self.count<self.n_acc: raise CancelBatchException()
        else: self.count=0
class GetGrads(Callback):
    "Test helper: keep a detached copy of every parameter gradient just before the optimizer step."
    # Training only. Ordered right after `GradientAccumulation`, so presumably batches it
    # cancels in `before_step` never reach this callback and only real steps are recorded.
    run_valid,order = False,GradientAccumulation.order+1
    def before_step(self): self.grads=to_detach(L([p.grad.clone() for p in self.model.parameters()]))
def _test_acc(bs,n,cbs=None,cuda=False):
    "Fit a synthetic learner for one epoch (data built from `bs`/`n`); return (train_loss, valid_loss, last recorded grads)."
    # Fixed seed so that runs with different (bs, n) splits are directly comparable.
    with no_random(99):
        db = synth_dbunch(bs=bs,n_train=n,n_valid=n,cuda=cuda)
        learn = synth_learner(data=db,cbs=[GetGrads]+L(cbs))
        learn.fit(1, lr=0.01)
        train,valid = learn.recorder.values[-1]
        return train,valid,learn.get_grads.grads
acc_cb = GradientAccumulation(n_acc=8)

# One batch of 8 items vs. eight batches of 1 item accumulated over 8 items.
train1,valid1,grads1 = _test_acc(8,1)
train2,valid2,grads2 = _test_acc(1,8,acc_cb)

# Gradients and validation loss must match; the logged training loss differs.
test_close(grads2,grads1)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch | train_loss | valid_loss | time |
0 | 0.834062 | 0.295950 | 00:00 |
epoch | train_loss | valid_loss | time |
0 | 0.824550 | 0.295950 | 00:00 |
# Same check as above, combined with mixed precision (requires CUDA).
fp16_cb = MixedPrecision(init_scale=1024)
train1,valid1,grads1 = _test_acc(8,1, fp16_cb, cuda=True)
train2,valid2,grads2 = _test_acc(1,8, [acc_cb,fp16_cb], cuda=True)
# Looser tolerance on gradients because of reduced precision.
test_close(grads2,grads1, eps=0.01)
test_close(valid2, valid1)
test_ne(train2, train1)
epoch | train_loss | valid_loss | time |
0 | 0.834062 | 0.295950 | 00:00 |
epoch | train_loss | valid_loss | time |
0 | 0.824550 | 0.295950 | 00:00 |
# If `n_acc` exceeds the number of training items seen, no optimizer step should happen
# (assumes the default synthetic training set is far smaller than 1000 items).
learn = synth_learner()
# Two epochs: with a single epoch `values[-1]` and `values[0]` are the same row,
# which would make the assertion below vacuous.
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1000))
# ensure valid_loss did not change
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
epoch | train_loss | valid_loss | time |
0 | 20.987558 | 26.849480 | 00:00 |
梯度裁剪
class GradientClip(Callback):
    "Clip norm of gradients"
    # Runs after `MixedPrecision` (presumably so gradients are already unscaled before clipping).
    order=MixedPrecision.order+1
    def __init__(self,max_norm:float=1., norm_type:float=2.0):
        "`max_norm`: largest allowed gradient norm; `norm_type`: order of the norm (2.0 = L2)."
        store_attr()
    def before_step(self): nn.utils.clip_grad_norm_(self.parameters(), self.max_norm, self.norm_type)
# Without clipping, a too-high learning rate makes the loss blow up.
fp16 = MixedPrecision()
# NOTE(review): the seed literal was lost in extraction; 99 is a reconstruction — confirm.
# The with/without-clipping comparison is meant to use the same seed.
set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=fp16)
epoch | train_loss | valid_loss | time |
0 | 38.214138 | 25.269005 | 00:00 |
1 | 377.145508 | 890.010376 | 00:00 |
2 | 839.392883 | 9965.747070 | 00:00 |
学习率过高时(例如上面 `lr=1.1` 的例子),训练会发散,损失持续增大。通过添加 `GradientClip` 回调,梯度的 `norm_type` 范数(默认值:2)会经由 `nn.utils.clip_grad_norm_` 被裁剪到不超过 `max_norm`(默认值:1),从而避免损失发散:
# NOTE(review): the seed literal was lost in extraction; 99 is a reconstruction — confirm.
set_seed(99)
learn = synth_learner(lr=1.1, cuda=True)
learn.fit(3, cbs=[GradientClip,fp16])
epoch | train_loss | valid_loss | time |
0 | 2.039428 | 2.372177 | 00:00 |
1 | 1.402425 | 0.300728 | 00:00 |
2 | 1.013548 | 0.332610 | 00:00 |
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval(m:nn.Module, use_eval:bool=True)->None:
    """Set bn layers in eval mode for all recursive children of `m`.

    Only BatchNorm layers whose parameters are frozen (`requires_grad=False`) are touched.
    With `use_eval=False` those layers are put back into train mode instead.
    """
    for l in m.children():
        if isinstance(l, bn_types):
            # `affine=False` BN layers have no parameters; skip them instead of raising StopIteration.
            p = next(l.parameters(), None)
            if p is not None and not p.requires_grad:
                if use_eval: l.eval()
                else:        l.train()
        # Recurse so BN layers nested inside sub-modules are reached, propagating `use_eval`.
        set_bn_eval(l, use_eval)
class BnFreeze(Callback):
    "Freeze moving average statistics in all non-trainable batchnorm layers."
    # NOTE(review): presumably required because `TrainEvalCallback` switches the model to
    # train mode before training; running after it keeps frozen BN layers in eval mode.
    run_after=TrainEvalCallback
    def before_train(self):
        set_bn_eval(self.model)
仅冻结模型参数在这里是不够的,因为 `BatchNorm` 层默认是可训练的,并且会跟踪各批次的运行均值和标准差。为了使特征提取器完全一致,你需要设置 `train_bn=False`,同时这些统计量也需要被冻结——这正是 `BnFreeze` 的作用。
::: {#cell-26 .cell 0=‘缓’ 1=‘慢’}
path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)
我们首先创建一个 `Learner`,来展示仅使用 `train_bn=False` 时运行统计量不一致的情况……
::: {#cell-28 .cell 0=‘缓’ 1=‘慢’}
# Baseline: `train_bn=False` only, no `BnFreeze` callback.
learn1 = vision_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
::: {#cell-30 .cell 0=‘缓’ 1=‘慢’}
# Snapshot the running mean of `learn1.model[0][1]` (presumably the stem BatchNorm layer).
m = learn1.model[0][1].running_mean.clone()
::: {#cell-32 .cell 0=‘缓’ 1=‘慢’}
learn1.fit(1, lr=0.02)
# Without BnFreeze the running statistics drift during training.
test_ne(to_detach(learn1.model[0][1].running_mean), m)
epoch | train_loss | valid_loss | time |
0 | 1.148303 | 0.739404 | 00:12 |
使用 `BnFreeze` 后,训练前后运行均值保持不变:
::: {#cell-34 .cell 0=‘缓’ 1=‘慢’}
learn1 = vision_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.detach().clone()
learn1.fit(1, lr=0.02)
# With BnFreeze the running statistics are unchanged by training.
test_eq(to_detach(learn1.model[0][1].running_mean), m)
epoch | train_loss | valid_loss | time |
0 | 0.478594 | 0.270772 | 00:10 |
导出
from nbdev import nbdev_export
