注意
点击这里下载完整的示例代码
(测试版) PyTorch中的Channels Last内存格式
创建于:2020年4月20日 | 最后更新:2023年10月4日 | 最后验证:2024年11月5日
作者: Vitaly Fedyunin
什么是Channels Last
Channels last 内存格式是一种在内存中保留维度顺序的 NCHW 张量的替代排序方式。Channels last 张量以这种方式排序,使得通道成为最密集的维度(即逐像素存储图像)。
例如,NCHW张量的经典(连续)存储(在我们的例子中,这是两个4x4图像,具有3个颜色通道)看起来像这样:

Channels last 内存格式以不同的方式排序数据:

Pytorch 通过利用现有的步幅结构支持内存格式(并提供与现有模型的向后兼容性,包括 eager、JIT 和 TorchScript)。 例如,10x3x16x16 批次在 Channels last 格式中将具有等于 (768, 1, 48, 3) 的步幅。
Channels last 内存格式仅针对4D NCHW张量实现。
内存格式API
以下是如何在连续内存格式和通道最后内存格式之间转换张量。
经典的 PyTorch 连续张量
import torch
N, C, H, W = 10, 3, 32, 32
x = torch.empty(N, C, H, W)
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)
转换运算符
x = x.to(memory_format=torch.channels_last)
print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride()) # Outputs: (3072, 1, 96, 3)
torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)
返回连续
x = x.to(memory_format=torch.contiguous_format)
print(x.stride()) # Outputs: (3072, 1024, 32, 1)
(3072, 1024, 32, 1)
替代选项
x = x.contiguous(memory_format=torch.channels_last)
print(x.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
格式检查
print(x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
True
这两个API to
和 contiguous
之间存在一些细微差别。我们建议在显式转换张量的内存格式时坚持使用 to
。
在一般情况下,这两个API的行为是相同的。然而,在特殊情况下,对于一个大小为NCHW
的4D张量,当C==1
或H==1 && W==1
时,只有to
会生成一个适当的步幅来表示通道最后的内存格式。
这是因为在上述两种情况下,张量的内存格式是模糊的,即具有大小N1HW
的连续张量在内存存储中既是contiguous
的,也是通道最后的。因此,对于给定的内存格式,它们已经被认为是is_contiguous
的,因此contiguous
调用将变为无操作,并且不会更新步幅。相反,to
会重新调整张量的步幅,以便在尺寸为1的维度上具有有意义的步幅,以正确表示预期的内存格式。
special_x = torch.empty(4, 1, 4, 4)
print(special_x.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
print(special_x.is_contiguous(memory_format=torch.contiguous_format)) # Outputs: True
True
True
同样适用于显式排列API permute
。在可能发生歧义的特殊情况下,permute
不保证生成能够正确携带预期内存格式的步幅。我们建议使用带有显式内存格式的to
以避免意外行为。
另外需要注意的是,在极端情况下,当三个非批处理维度都等于1
(C==1 && H==1 && W==1
)时,当前的实现无法将张量标记为通道最后的内存格式。
创建为最后的频道
x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
clone
保留内存格式
(3072, 1, 96, 3)
to
, cuda
, float
… 保留内存格式
if torch.cuda.is_available():
y = x.cuda()
print(y.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
empty_like
, *_like
操作符保留内存格式
y = torch.empty_like(x)
print(y.stride()) # Outputs: (3072, 1, 96, 3)
(3072, 1, 96, 3)
逐点运算符保留内存格式
(3072, 1, 96, 3)
Conv
, Batchnorm
模块使用 cudnn
后端支持通道最后(仅适用于 cuDNN >= 7.6)。与二元逐点操作符不同,卷积模块的通道最后是主要的内存格式。如果所有输入都是连续的内存格式,操作符将生成连续内存格式的输出。否则,输出将是通道最后的内存格式。
if torch.backends.cudnn.is_available() and torch.backends.cudnn.version() >= 7603:
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last) # Module parameters need to be channels last
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)
out = model(input)
print(out.is_contiguous(memory_format=torch.channels_last)) # Outputs: True
True
当输入张量到达一个不支持通道最后顺序的运算符时,内核中应自动应用置换以恢复输入张量的连续性。这会引入开销并停止通道最后内存格式的传播。尽管如此,它保证了正确的输出。
性能提升
通道最后内存格式优化在GPU和CPU上均可使用。
在GPU上,最显著的性能提升是在支持Tensor Cores的NVIDIA硬件上,以降低精度运行
(torch.float16
)。
与连续格式相比,我们能够在利用‘AMP(自动混合精度)’训练脚本时,通过通道最后格式实现超过22%的性能提升。
我们的脚本使用了NVIDIA提供的AMP
https://github.com/NVIDIA/apex。
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)
传递 --channels-last true
允许以Channels last格式运行模型,观察到22%的性能提升。
python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data
# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2: FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled : True
# opt_level : O2
# cast_model_type : torch.float16
# patch_torch_functions : False
# keep_batchnorm_fp32 : True
# master_weights : True
# loss_scale : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)
The following list of models has the full support of Channels last and showing 8%-35% performance gains on Volta devices:
alexnet
, mnasnet0_5
, mnasnet0_75
, mnasnet1_0
, mnasnet1_3
, mobilenet_v2
, resnet101
, resnet152
, resnet18
, resnet34
, resnet50
, resnext50_32x4d
, shufflenet_v2_x0_5
, shufflenet_v2_x1_0
, shufflenet_v2_x1_5
, shufflenet_v2_x2_0
, squeezenet1_0
, squeezenet1_1
, vgg11
, vgg11_bn
, vgg13
, vgg13_bn
, vgg16
, vgg16_bn
, vgg19
, vgg19_bn
, wide_resnet101_2
, wide_resnet50_2
The following list of models has the full support of Channels last and showing 26%-76% performance gains on Intel(R) Xeon(R) Ice Lake (or newer) CPUs:
alexnet
, densenet121
, densenet161
, densenet169
, googlenet
, inception_v3
, mnasnet0_5
, mnasnet1_0
, resnet101
, resnet152
, resnet18
, resnet34
, resnet50
, resnext101_32x8d
, resnext50_32x4d
, shufflenet_v2_x0_5
, shufflenet_v2_x1_0
, squeezenet1_0
, squeezenet1_1
, vgg11
, vgg11_bn
, vgg13
, vgg13_bn
, vgg16
, vgg16_bn
, vgg19
, vgg19_bn
, wide_resnet101_2
, wide_resnet50_2
转换现有模型
通道最后支持不受现有模型的限制,因为任何模型都可以转换为通道最后,并在输入(或某些权重)正确格式化后通过图传播格式。
# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last) # Replace with your model
# Need to be done for every input
input = input.to(memory_format=torch.channels_last) # Replace with your input
output = model(input)
然而,并非所有运算符都完全转换为支持通道最后(通常返回连续输出)。在上面发布的示例中,不支持通道最后的层将停止内存格式传播。尽管如此,由于我们已经将模型转换为通道最后格式,这意味着每个卷积层(其4维权重在通道最后内存格式中)将恢复通道最后内存格式,并从更快的内核中受益。
但是不支持通道最后的操作符确实会通过排列引入开销。可选地,如果您想提高转换模型的性能,可以调查并识别模型中不支持通道最后的操作符。
这意味着你需要根据支持的运算符列表验证使用的运算符列表https://github.com/pytorch/pytorch/wiki/Operators-with-Channels-Last-support,或者在急切执行模式中引入内存格式检查并运行你的模型。
运行以下代码后,如果操作符的输出与输入的内存格式不匹配,操作符将引发异常。
def contains_cl(args):
for t in args:
if isinstance(t, torch.Tensor):
if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
return True
elif isinstance(t, list) or isinstance(t, tuple):
if contains_cl(list(t)):
return True
return False
def print_inputs(args, indent=""):
for t in args:
if isinstance(t, torch.Tensor):
print(indent, t.stride(), t.shape, t.device, t.dtype)
elif isinstance(t, list) or isinstance(t, tuple):
print(indent, type(t))
print_inputs(list(t), indent=indent + " ")
else:
print(indent, t)
def check_wrapper(fn):
name = fn.__name__
def check_cl(*args, **kwargs):
was_cl = contains_cl(args)
try:
result = fn(*args, **kwargs)
except Exception as e:
print("`{}` inputs are:".format(name))
print_inputs(args)
print("-------------------")
raise e
failed = False
if was_cl:
if isinstance(result, torch.Tensor):
if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
print(
"`{}` got channels_last input, but output is not channels_last:".format(name),
result.shape,
result.stride(),
result.device,
result.dtype,
)
failed = True
if failed and True:
print("`{}` inputs are:".format(name))
print_inputs(args)
raise Exception("Operator `{}` lost channels_last property".format(name))
return result
return check_cl
old_attrs = dict()
def attribute(m):
old_attrs[m] = dict()
for i in dir(m):
e = getattr(m, i)
exclude_functions = ["is_cuda", "has_names", "numel", "stride", "Tensor", "is_contiguous", "__class__"]
if i not in exclude_functions and not i.startswith("_") and "__call__" in dir(e):
try:
old_attrs[m][i] = e
setattr(m, i, check_wrapper(e))
except Exception as e:
print(i)
print(e)
attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)
如果您发现了一个不支持通道最后张量的操作符,并且您想要贡献代码,请随意使用以下开发者指南 https://github.com/pytorch/pytorch/wiki/Writing-memory-format-aware-operators。
下面的代码用于恢复torch的属性。
for (m, attrs) in old_attrs.items():
for (k, v) in attrs.items():
setattr(m, k, v)
待办工作
还有很多事情要做,例如:
解决
N1HW
和NC11
张量的歧义;分布式训练支持的测试;
提高操作员的覆盖率。
如果您有反馈和/或改进建议,请通过创建问题告知我们。
脚本总运行时间: ( 0 分钟 0.043 秒)