```python
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab
```
MixUp and Friends
```python
from __future__ import annotations
from fastai.basics import *
from torch.distributions.beta import Beta
from nbdev.showdoc import *
from fastai.test_utils import *
```
Callbacks that can apply the MixUp (and variants) data augmentation to your training
```python
from fastai.vision.all import *
```
```python
def reduce_loss(
    loss:Tensor,
    reduction:str='mean' # PyTorch loss reduction
)->Tensor:
    "Reduce the loss based on `reduction`"
    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss
```
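For example, any `reduction` other than `'mean'` or `'sum'` leaves the per-item losses untouched (a quick illustration, not part of the library code):

```python
loss = torch.tensor([0.5, 1.5, 2.5])
reduce_loss(loss, 'mean')  # tensor(1.5000)
reduce_loss(loss, 'sum')   # tensor(4.5000)
reduce_loss(loss, 'none')  # tensor([0.5000, 1.5000, 2.5000]), unchanged
```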
```python
class MixHandler(Callback):
    "A handler class for implementing `MixUp` style scheduling"
    run_valid = False
    def __init__(self,
        alpha:float=0.5 # Determine `Beta` distribution in range (0.,inf]
    ):
        self.distrib = Beta(tensor(alpha), tensor(alpha))

    def before_train(self):
        "Determine whether to stack y"
        self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        if self.stack_y: self.old_lf,self.learn.loss_func = self.learn.loss_func,self.lf

    def after_train(self):
        "Set the loss function back to the previous loss"
        if self.stack_y: self.learn.loss_func = self.old_lf

    def after_cancel_train(self):
        "If training is canceled, still set the loss function back"
        self.after_train()

    def after_cancel_fit(self):
        "If fit is canceled, still set the loss function back"
        self.after_train()

    def lf(self, pred, *yb):
        "lf is a loss function that applies the original loss function on both outputs based on `self.lam`"
        if not self.training: return self.old_lf(pred, *yb)
        with NoneReduce(self.old_lf) as lf:
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        return reduce_loss(loss, getattr(self.old_lf, 'reduction', 'mean'))
```
Most `Mix` variants will perform the data augmentation on the batch, so to implement your own `Mix` you should adjust the `before_batch` event to the needs of your training regimen. If a different loss function is needed, you should also adjust `lf`. `alpha` is passed to `Beta` to create a sampler.
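As a minimal sketch of that pattern, here is a hypothetical toy variant (`GrayMix` is not part of fastai) that blends each image with a flat gray image rather than with another sample. Note that it only has to set `self.lam`, `self.yb1`, and the new `self.learn.xb`:

```python
class GrayMix(MixHandler):
    "Toy example: blend each image with a constant gray image"
    def before_batch(self):
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        self.lam = torch.stack([lam, 1-lam], 1).max(1)[0]  # keep >= 50% original
        self.yb1 = self.yb                                 # targets are unchanged here
        gray = torch.full_like(self.x, 0.5)
        nx_dims = len(self.x.size())
        self.learn.xb = (torch.lerp(gray, self.x, unsqueeze(self.lam, n=nx_dims-1)),)
```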
MixUp -
```python
class MixUp(MixHandler):
    "Implementation of https://arxiv.org/abs/1710.09412"
    def __init__(self,
        alpha:float=.4 # Determine `Beta` distribution in range (0.,inf]
    ):
        super().__init__(alpha)

    def before_batch(self):
        "Blend xb and yb with another random item in a second batch (xb1,yb1) with `lam` weights"
        lam = self.distrib.sample((self.y.size(0),)).squeeze().to(self.x.device)
        lam = torch.stack([lam, 1-lam], 1)
        self.lam = lam.max(1)[0]
        shuffle = torch.randperm(self.y.size(0)).to(self.x.device)
        xb1,self.yb1 = tuple(L(self.xb).itemgot(shuffle)),tuple(L(self.yb).itemgot(shuffle))
        nx_dims = len(self.x.size())
        self.learn.xb = tuple(L(xb1,self.xb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=nx_dims-1)))

        if not self.stack_y:
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))
```
This is a modified implementation of mixup that will always blend at least 50% of the original image. The original paper calls for a Beta distribution which is passed the same value of alpha for each position in the loss function (alpha = beta = #). Unlike the original paper, this implementation of mixup selects the max of lambda, which means that if the sampled value of lambda is less than 0.5 (i.e. the original image would be represented less than 50%), 1-lambda is used instead.
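Concretely, this is the `torch.stack(...).max(1)[0]` step in `before_batch` above. With some illustrative sampled values:

```python
lam = torch.tensor([0.2, 0.5, 0.9])   # hypothetical Beta samples for 3 items
lam = torch.stack([lam, 1-lam], 1)    # tensor([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]])
lam.max(1)[0]                         # tensor([0.8, 0.5, 0.9]): always >= 0.5
```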
The blending of the two images is determined by `alpha`.

\(\alpha = 1\):

- All values between 0 and 1 have an equal chance of being sampled.
- Any amount of mixing between the two images is possible.

\(\alpha < 1\):

- Values closer to 0 and 1 are more likely to be sampled than values near 0.5.
- It is more likely that one of the images will dominate, with a small amount of the other mixed in.

\(\alpha > 1\):

- Values closer to 0.5 are more likely than values near 0 or 1.
- It is more likely that the images will be blended evenly.
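A quick way to see this empirically (illustrative only, not part of the library code) is to sample from `Beta(alpha, alpha)` at a few settings; the mean stays at 0.5, but the spread changes:

```python
for alpha in (0.4, 1., 5.):
    samples = Beta(tensor(alpha), tensor(alpha)).sample((10_000,))
    print(f"alpha={alpha}: mean={samples.mean():.2f}, std={samples.std():.2f}")
# alpha<1 gives a high std (mass near 0 and 1); alpha>1 concentrates near 0.5
```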
First we'll look at a very minimalistic example to show how our data is being generated, using the `PETS` dataset:
```python
path       = untar_data(URLs.PETS)
pat        = r'([^/]+)_\d+.*$'
fnames     = get_image_files(path/'images')
item_tfms  = [Resize(256, method='crop')]
batch_tfms = [*aug_transforms(size=224), Normalize.from_stats(*imagenet_stats)]
dls = ImageDataLoaders.from_name_re(path, fnames, pat, bs=64, item_tfms=item_tfms,
                                    batch_tfms=batch_tfms)
```
We can examine the results of our `Callback` by grabbing our data during `fit` at `before_batch` like so:
```python
mixup = MixUp(1.)
with Learner(dls, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=mixup) as learn:
    learn.epoch,learn.training = 0,True
    learn.dl = dls.train
    b = dls.one_batch()
    learn._split(b)
    learn('before_train')
    learn('before_batch')

_,axs = plt.subplots(3,3, figsize=(9,9))
dls.show_batch(b=(mixup.x,mixup.y), ctxs=axs.flatten())
```
```python
test_ne(b[0], mixup.x)
test_eq(b[1], mixup.y)
```
We can see that every so often an image gets "mixed" with another.
How do we train? You can pass the `Callback` either to `Learner` directly or to `cbs` in your fit function:
```python
learn = vision_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(), metrics=[error_rate])
learn.fit_one_cycle(1, cbs=mixup)
```
| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 2.041960 | 0.495492 | 0.162382 | 00:12 |
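Alternatively, a sketch of the other option mentioned above (it behaves the same): attach the callback to the `Learner` itself, so every `fit` call uses it:

```python
learn = vision_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(),
                       metrics=[error_rate], cbs=MixUp(0.4))
learn.fit_one_cycle(1)
```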
CutMix -
```python
class CutMix(MixHandler):
    "Implementation of https://arxiv.org/abs/1905.04899"
    def __init__(self,
        alpha:float=1. # Determine `Beta` distribution in range (0.,inf]
    ):
        super().__init__(alpha)

    def before_batch(self):
        "Add `rand_bbox` patches with size based on `lam` and location chosen randomly."
        bs, _, H, W = self.x.size()
        self.lam = self.distrib.sample((1,)).to(self.x.device)
        shuffle = torch.randperm(bs).to(self.x.device)
        xb1,self.yb1 = self.x[shuffle], tuple((self.y[shuffle],))
        x1, y1, x2, y2 = self.rand_bbox(W, H, self.lam)
        self.learn.xb[0][..., y1:y2, x1:x2] = xb1[..., y1:y2, x1:x2]
        self.lam = (1 - ((x2-x1)*(y2-y1))/float(W*H))
        if not self.stack_y:
            ny_dims = len(self.y.size())
            self.learn.yb = tuple(L(self.yb1,self.yb).map_zip(torch.lerp,weight=unsqueeze(self.lam, n=ny_dims-1)))

    def rand_bbox(self,
        W:int, # input image width
        H:int, # input image height
        lam:Tensor # lambda sampled from a Beta distribution, e.g. tensor([0.3647])
    ) -> tuple: # representing the top-left and bottom-right pixel locations
        "Give a bounding box location based on the size of the im and a weight"
        cut_rat = torch.sqrt(1. - lam).to(self.x.device)
        cut_w = torch.round(W * cut_rat).type(torch.long).to(self.x.device)
        cut_h = torch.round(H * cut_rat).type(torch.long).to(self.x.device)
        # uniform
        cx = torch.randint(0, W, (1,)).to(self.x.device)
        cy = torch.randint(0, H, (1,)).to(self.x.device)
        x1 = torch.clamp(cx - torch.div(cut_w, 2, rounding_mode='floor'), 0, W)
        y1 = torch.clamp(cy - torch.div(cut_h, 2, rounding_mode='floor'), 0, H)
        x2 = torch.clamp(cx + torch.div(cut_w, 2, rounding_mode='floor'), 0, W)
        y2 = torch.clamp(cy + torch.div(cut_h, 2, rounding_mode='floor'), 0, H)
        return x1, y1, x2, y2
```
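Note how the sampled `lam` controls the patch size: `cut_rat = sqrt(1 - lam)`, so the patch covers roughly a `1 - lam` fraction of the image (less when the box is clipped at a border), and `lam` is then recomputed from the actual patch area. A rough numeric check with illustrative values:

```python
W = H = 224
lam = torch.tensor([0.36])                # hypothetical Beta sample
cut_rat = torch.sqrt(1. - lam)            # tensor([0.8000])
cut_w = torch.round(W * cut_rat).long()   # tensor([179])
cut_h = torch.round(H * cut_rat).long()   # tensor([179])
# If the patch is not clipped at a border, its area fraction is:
(cut_w * cut_h).float() / (W * H)         # tensor([0.6386]) -> recomputed lam ~ 0.36
```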
Similar to `MixUp`, `CutMix` will cut out a random box from two images and swap it between them. We can look at a few examples below:
```python
cutmix = CutMix(1.)
with Learner(dls, nn.Linear(3,4), loss_func=CrossEntropyLossFlat(), cbs=cutmix) as learn:
    learn.epoch,learn.training = 0,True
    learn.dl = dls.train
    b = dls.one_batch()
    learn._split(b)
    learn('before_train')
    learn('before_batch')

_,axs = plt.subplots(3,3, figsize=(9,9))
dls.show_batch(b=(cutmix.x,cutmix.y), ctxs=axs.flatten())
```
We train with it in the exact same way.
```python
learn = vision_learner(dls, resnet18, loss_func=CrossEntropyLossFlat(), metrics=[accuracy, error_rate])
learn.fit_one_cycle(1, cbs=cutmix)
```
| epoch | train_loss | valid_loss | accuracy | error_rate | time |
|---|---|---|---|---|---|
| 0 | 3.440883 | 0.793059 | 0.769959 | 0.230041 | 00:12 |
Export -
```python
from nbdev import nbdev_export
nbdev_export()
```
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted index_original.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.