! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastai
通道最后训练
from __future__ import annotations
from fastai.basics import *
from fastai.callback.fp16 import AMPMode, MixedPrecision
from torch.cuda.amp import GradScaler
from fastai.test_utils import *
from nbdev.showdoc import *
使用通道最后格式更快地训练模型(测试版)
使用MixedPrecision
,在Tensor Cores上以通道最后格式训练的图像模型可以提高训练吞吐量,相比于连续格式。PyTorch观察到在使用通道最后格式进行ResNet50训练时速度提高了22%,在V100上测试的多个模型中提高了8-35%。
通道最后格式与现代GPU(Volta、Turing或更新版本)和现代CPU(Ice Lake或更新版本)兼容。
当前,通道最后内存格式已针对NCHW张量实现。并非所有的PyTorch操作都已转换为支持通道最后格式。有关更多细节,请参见(Beta) PyTorch中的通道最后内存格式教程。
ChannelsLast -
class ChannelsLast(Callback):
"Channels last training using PyTorch's Channels Last Memory Format (beta)"
= -1 # 需要在任何模型修改回调发生之前运行
order def before_fit(self):
self.learn.model.to(memory_format=torch.channels_last)
当PyTorch模型设置为通道最后格式时,PyTorch会自动将任何兼容的NCHW输入张量转换为NHWC格式。ChannelsLast
将模型设置为通道最后格式,因此不需要更改数据加载器或输入。
Note
ChannelsLast
应该适用于大多数卷积timm
模型。
然而,建议测试每个模型,因为不同的PyTorch版本支持的操作有所不同。
在不支持的PyTorch操作中使用ChannelsLast
可能导致“通道抖动”,即通道最后的输入在不支持的PyTorch操作中被转换为连续格式,然后在张量核心上执行时又转换回通道最后格式,当返回到操作时又回到连续格式,最后再转换为通道最后格式以供下一个层使用。模型中如果有过多不支持的操作可能会导致性能下降。
@patch
@delegates(GradScaler)
def to_channelslast(self:Learner,
bool=True, # 添加 `MixedPrecision` 并设置 `amp_mode`。推荐用于实现全通道的最后性能。
use_amp:str|AMPMode=AMPMode.FP16, # 混合精度训练模式。支持fp16和bf16。
amp_mode:**kwargs
):"Set `Learner` and inputs to `channels_last` format and float16 Mixed Precision by default"
if use_amp and not hasattr(self, 'mixed_precision') and not hasattr(self, 'channels_last'):
return self.add_cbs([ChannelsLast(), MixedPrecision(amp_mode, **kwargs)])
elif not hasattr(self, 'channels_last'):
return self.add_cb(ChannelsLast())
@patch
def to_contiguous(self:Learner, to_fp32:bool=False):
"Set `Learner` and inputs to `contiguous_format` (default format), optionally to single precision"
self.model.to(memory_format=torch.contiguous_format)
if to_fp32:
return self.remove_cbs([ChannelsLast, MixedPrecision])
else:
return self.remove_cb(ChannelsLast)
测试渠道最后 -
from torch.utils.data import TensorDataset
class ChannelsLastTest(Callback):
"Asserts that predictions are in channels last format"
= MixedPrecision.order-1
order def after_pred(self):
assert self.pred.is_contiguous(memory_format=torch.channels_last), "Model and/or output isn't channels last"
#|cuda
def synth_dbunch(bs=16, n_train=10, n_valid=2, cuda=True):
def get_data(n):
return TensorDataset(TensorImage(torch.randn(bs*n, 3, 32, 32)))
= get_data(n_train)
train_ds = get_data(n_valid)
valid_ds = default_device() if cuda else None
device = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)
train_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)
valid_dl return DataLoaders(train_dl, valid_dl, device=device)
隐藏#|cuda
# 测试必须在现代硬件(Volta、Turning或更新版本)上进行。
with no_random():
= synth_learner(cbs=[MixedPrecision,ChannelsLast,ChannelsLastTest], cuda=True, data=synth_dbunch())
learn class ConvModel(Module):
def __init__(self): self.conv = nn.Conv2d(3, 32, 1)
def forward(self,x): return self.conv(x)
def fakeloss(): pass
= ConvModel()
learn.model = partial(SGD, mom=0.)
learn.opt_func =fakeloss
learn.loss_func3) learn.fit(
epoch | train_loss | valid_loss | time |
---|---|---|---|
0 | nan | None | 00:01 |
1 | nan | None | 00:00 |
2 | nan | None | 00:00 |
导出 -
from nbdev import *
nbdev_export()