! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on Colab
Dynamic UNet
from __future__ import annotations
from fastai.torch_basics import *
from fastai.callback.hook import *
from nbdev.showdoc import *
A U-Net model that uses PixelShuffle ICNR upsampling and that can be built on top of any pretrained architecture.
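To make the upsampling trick concrete, here is a minimal sketch in plain PyTorch (not part of the exported module; `scale`, `ni`, `nf` are made-up names) of the idea behind `PixelShuffle_ICNR`: a convolution expands the channels by `scale**2`, and ICNR initialization repeats each filter across the `scale**2` sub-pixel slots, so the following `nn.PixelShuffle` starts out as nearest-neighbour upsampling and avoids checkerboard artifacts.

import torch, torch.nn as nn

scale, ni, nf = 2, 16, 32
conv = nn.Conv2d(ni, nf*scale**2, kernel_size=1)
# ICNR: initialize nf filters, then repeat each one scale**2 times so every
# sub-pixel position within a shuffle block starts with the same value.
w = torch.empty(nf, ni, 1, 1)
nn.init.kaiming_normal_(w)
with torch.no_grad():
    conv.weight.copy_(w.repeat_interleave(scale**2, dim=0))
    conv.bias.zero_()
up = nn.Sequential(conv, nn.PixelShuffle(scale))
x = torch.randn(1, ni, 8, 8)
assert up(x).shape == (1, nf, 16, 16)  # 2x, nearest-neighbour-like at init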
def _get_sz_change_idxs(sizes):
    "Get the indexes of the layers where the size of the activation changes."
    feature_szs = [size[-1] for size in sizes]
    sz_chg_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
    return sz_chg_idxs
test_eq(_get_sz_change_idxs([[3,64,64], [16,64,64], [32,32,32], [16,32,32], [32,32,32], [16,16]]), [1,4])
test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32], [32,32,32], [16,32,32], [32,16,16], [16,16]]), [0,3])
test_eq(_get_sz_change_idxs([[3,64,64]]), [])
test_eq(_get_sz_change_idxs([[3,64,64], [16,32,32]]), [0])
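To see how this pairs with the hooks machinery, here is a small sketch on a made-up three-layer encoder: `model_sizes` (from `fastai.callback.hook`) records the activation shape after each top-level layer, and `_get_sz_change_idxs` marks the layers after which the spatial size shrinks, which is where `DynamicUnet` attaches its skip-connection hooks.

enc = nn.Sequential(nn.Conv2d(3, 8, 3, stride=1, padding=1),
                    nn.Conv2d(8, 16, 3, stride=2, padding=1),
                    nn.Conv2d(16, 32, 3, stride=2, padding=1))
szs = model_sizes(enc, size=(32,32))      # widths are [32, 16, 8]
test_eq(_get_sz_change_idxs(szs), [0,1])  # size changes after layers 0 and 1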
class UnetBlock(Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR` upsampling."
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, hook, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        self.hook = hook
        self.shuf = PixelShuffle_ICNR(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type)
        self.bn = BatchNorm(x_in_c)
        ni = up_in_c//2 + x_in_c
        nf = ni if final_div else ni//2
        self.conv1 = ConvLayer(ni, nf, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, act_cls=act_cls, norm_type=norm_type,
                               xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in):
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))
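Because `UnetBlock` needs a live hook, exercising it in isolation takes a little setup. A hedged sketch (all shapes and layers below are illustrative; `hook_output` comes from `fastai.callback.hook`):

lay  = nn.Conv2d(3, 16, 3, padding=1)        # stands in for an encoder layer
hook = hook_output(lay, detach=False)        # will hold the skip activation
skip = lay(torch.randn(1, 3, 32, 32))        # hook.stored is now 1x16x32x32
blk  = UnetBlock(up_in_c=32, x_in_c=16, hook=hook)
up_in = torch.randn(1, 32, 16, 16)           # deeper, lower-resolution features
test_eq(blk(up_in).shape, [1, 32, 32, 32])   # up_in_c//2 + x_in_c channels out
hook.remove()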
class ResizeToOrig(Module):
    "Resize the output to have the same spatial size as the original input (stored in `x.orig` by `SequentialEx`)."
    def __init__(self, mode='nearest'): self.mode = mode
    def forward(self, x):
        if x.orig.shape[-2:] != x.shape[-2:]:
            x = F.interpolate(x, x.orig.shape[-2:], mode=self.mode)
        return x
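Inside `SequentialEx`, every module receives the pipeline's input as `x.orig`; we can mimic that by attaching the attribute by hand. A quick sketch with made-up shapes:

x = torch.randn(1, 4, 31, 32)
x.orig = torch.randn(1, 3, 63, 64)               # SequentialEx attaches this for us
test_eq(ResizeToOrig()(x).shape, [1, 4, 63, 64])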
class DynamicUnet(SequentialEx):
"Create a U-Net from a given architecture."
    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        imsize = img_size
        # Hook the layers where the feature map size changes; these provide the skip connections.
        sizes = model_sizes(encoder, size=imsize)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        # Build the decoder: one UnetBlock per size change, from deepest to shallowest.
        for i,idx in enumerate(sz_chg_idxs):
            not_final = i != len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sz_chg_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sizes[0][-2:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))
        layers += [ConvLayer(ni, n_out, ks=1, act_cls=None, norm_type=norm_type, **kwargs)]
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        #apply_init(nn.Sequential(layers[2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        layers.append(ToTensorBase())
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()
from fastai.vision.models import resnet34
m = resnet34()
m = nn.Sequential(*list(m.children())[:-2])
tst = DynamicUnet(m, 5, (128,128), norm_type=None)
x = cast(torch.randn(2, 3, 128, 128), TensorImage)
y = tst(x)
test_eq(y.shape, [2, 5, 128, 128])
tst = DynamicUnet(m, 5, (128,128), norm_type=None)
x = torch.randn(2, 3, 127, 128)
y = tst(x)
Export -
from nbdev import *
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 36_text.models.qrnn.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.