通道最后训练

! [ -e /content ] && pip install -Uqq fastai  # 在Colab上升级fastai
from __future__ import annotations
from fastai.basics import *
from fastai.callback.fp16 import AMPMode, MixedPrecision

from torch.cuda.amp import GradScaler
from fastai.test_utils import *
from nbdev.showdoc import *

使用通道最后格式更快地训练模型(测试版)

使用MixedPrecision,在Tensor Cores上以通道最后格式训练的图像模型可以提高训练吞吐量,相比于连续格式。PyTorch观察到在使用通道最后格式进行ResNet50训练时速度提高了22%,在V100上测试的多个模型中提高了8-35%。

通道最后格式与现代GPU(Volta、Turing或更新版本)和现代CPU(Ice Lake或更新版本)兼容。

当前,通道最后内存格式已针对NCHW张量实现。并非所有的PyTorch操作都已转换为支持通道最后格式。有关更多细节,请参见(Beta) PyTorch中的通道最后内存格式教程。

ChannelsLast -

class ChannelsLast(Callback):
    "Channels last training using PyTorch's Channels Last Memory Format (beta)"
    order = -1 # 需要在任何模型修改回调发生之前运行
    def before_fit(self):
        self.learn.model.to(memory_format=torch.channels_last)

当PyTorch模型设置为通道最后格式时,PyTorch会自动将任何兼容的NCHW输入张量转换为NHWC格式。ChannelsLast将模型设置为通道最后格式,因此不需要更改数据加载器或输入。

Note

ChannelsLast应该适用于大多数卷积timm模型。

然而,建议测试每个模型,因为不同的PyTorch版本支持的操作有所不同。

在不支持的PyTorch操作中使用ChannelsLast可能导致“通道抖动”,即通道最后的输入在不支持的PyTorch操作中被转换为连续格式,然后在张量核心上执行时又转换回通道最后格式,当返回到操作时又回到连续格式,最后再转换为通道最后格式以供下一个层使用。模型中如果有过多不支持的操作可能会导致性能下降。

@patch
@delegates(GradScaler)
def to_channelslast(self:Learner,
    use_amp:bool=True, # 添加 `MixedPrecision` 并设置 `amp_mode`。推荐用于实现全通道的最后性能。
    amp_mode:str|AMPMode=AMPMode.FP16, # 混合精度训练模式。支持fp16和bf16。
    **kwargs
):
    "Set `Learner` and inputs to `channels_last` format and float16 Mixed Precision by default"
    if use_amp and not hasattr(self, 'mixed_precision') and not hasattr(self, 'channels_last'):
        return self.add_cbs([ChannelsLast(), MixedPrecision(amp_mode, **kwargs)])
    elif not hasattr(self, 'channels_last'):
        return self.add_cb(ChannelsLast())
@patch
def to_contiguous(self:Learner, to_fp32:bool=False):
    "Set `Learner` and inputs to `contiguous_format` (default format), optionally to single precision"
    self.model.to(memory_format=torch.contiguous_format)
    if to_fp32:
        return self.remove_cbs([ChannelsLast, MixedPrecision])
    else:
        return self.remove_cb(ChannelsLast)

测试渠道最后 -

from torch.utils.data import TensorDataset
class ChannelsLastTest(Callback):
    "Asserts that predictions are in channels last format"
    order = MixedPrecision.order-1
    def after_pred(self):
        assert self.pred.is_contiguous(memory_format=torch.channels_last), "Model and/or output isn't channels last"
#|cuda
def synth_dbunch(bs=16, n_train=10, n_valid=2, cuda=True):
    def get_data(n):
        return TensorDataset(TensorImage(torch.randn(bs*n, 3, 32, 32)))
    train_ds = get_data(n_train)
    valid_ds = get_data(n_valid)
    device = default_device() if cuda else None
    train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)
    valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)
    return DataLoaders(train_dl, valid_dl, device=device)
隐藏
#|cuda
# 测试必须在现代硬件(Volta、Turning或更新版本)上进行。
with no_random():
    learn = synth_learner(cbs=[MixedPrecision,ChannelsLast,ChannelsLastTest], cuda=True, data=synth_dbunch())
    class ConvModel(Module):
        def __init__(self): self.conv = nn.Conv2d(3, 32, 1)
        def forward(self,x): return self.conv(x)
    def fakeloss(): pass
    learn.model = ConvModel()
    learn.opt_func = partial(SGD, mom=0.)
    learn.loss_func=fakeloss
    learn.fit(3)
epoch train_loss valid_loss time
0 nan None 00:01
1 nan None 00:00
2 nan None 00:00

导出 -

from nbdev import *
nbdev_export()