注意
点击 这里 下载完整的示例代码
如何通过将优化器步骤融合到反向传递中来节省内存
创建于:2023年10月02日 | 最后更新:2024年1月16日 | 最后验证:2024年11月05日
你好!本教程旨在展示一种通过减少梯度占用的内存来减少训练循环内存占用的方法。假设你有一个模型,并且你对优化内存以避免内存不足
(OOM)错误或仅仅是为了更好地利用你的GPU感兴趣。那么,你可能会很幸运(如果梯度占用了你的一部分内存并且你不需要进行梯度累积)。我们将探讨以下内容:
在你的训练或微调循环中,什么占用了内存,
如何捕获和可视化内存快照以确定瓶颈,
新的
Tensor.register_post_accumulate_grad_hook(hook)
API,最后,如何在10行代码中实现内存节省。
要运行本教程,您需要:
PyTorch 2.1.0 或更新版本,带有
torchvision
1 CUDA GPU 如果您想在本地运行内存可视化。 否则,此技术在任何设备上都会同样受益。
让我们从导入所需的模块和模型开始。我们将使用来自torchvision的视觉变换器模型,但请随意替换为您自己的模型。我们还将使用torch.optim.Adam
作为我们的优化器,但同样,请随意替换为您自己的优化器。
import torch
from torchvision import models
from pickle import dump
model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
Downloading: "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/vit_l_16-852ce7e3.pth
0%| | 0.00/1.13G [00:00<?, ?B/s]
1%|1 | 16.5M/1.13G [00:00<00:10, 115MB/s]
3%|2 | 32.9M/1.13G [00:00<00:09, 127MB/s]
4%|4 | 49.2M/1.13G [00:00<00:11, 106MB/s]
6%|5 | 65.6M/1.13G [00:00<00:09, 116MB/s]
7%|7 | 81.4M/1.13G [00:00<00:08, 129MB/s]
8%|8 | 94.4M/1.13G [00:00<00:10, 107MB/s]
9%|9 | 106M/1.13G [00:00<00:10, 108MB/s]
10%|# | 116M/1.13G [00:01<00:11, 97.5MB/s]
11%|#1 | 131M/1.13G [00:01<00:12, 86.0MB/s]
13%|#2 | 148M/1.13G [00:01<00:10, 101MB/s]
14%|#4 | 164M/1.13G [00:01<00:10, 99.8MB/s]
16%|#5 | 180M/1.13G [00:01<00:09, 103MB/s]
17%|#6 | 196M/1.13G [00:01<00:08, 116MB/s]
18%|#7 | 208M/1.13G [00:02<00:09, 107MB/s]
19%|#8 | 218M/1.13G [00:02<00:09, 100MB/s]
20%|#9 | 229M/1.13G [00:02<00:11, 85.9MB/s]
21%|##1 | 246M/1.13G [00:02<00:09, 97.1MB/s]
23%|##2 | 262M/1.13G [00:02<00:09, 101MB/s]
24%|##3 | 278M/1.13G [00:02<00:08, 114MB/s]
25%|##4 | 289M/1.13G [00:02<00:09, 101MB/s]
26%|##5 | 300M/1.13G [00:03<00:09, 99.4MB/s]
27%|##6 | 311M/1.13G [00:03<00:09, 92.8MB/s]
28%|##8 | 328M/1.13G [00:03<00:08, 98.2MB/s]
29%|##9 | 339M/1.13G [00:03<00:09, 95.4MB/s]
30%|##9 | 348M/1.13G [00:03<00:08, 94.9MB/s]
31%|###1 | 360M/1.13G [00:03<00:08, 94.6MB/s]
32%|###2 | 377M/1.13G [00:03<00:08, 97.1MB/s]
34%|###3 | 393M/1.13G [00:04<00:07, 102MB/s]
35%|###5 | 410M/1.13G [00:04<00:08, 94.0MB/s]
37%|###6 | 426M/1.13G [00:04<00:07, 103MB/s]
38%|###8 | 442M/1.13G [00:04<00:07, 101MB/s]
39%|###9 | 458M/1.13G [00:04<00:06, 115MB/s]
40%|#### | 470M/1.13G [00:04<00:06, 108MB/s]
41%|####1 | 481M/1.13G [00:04<00:07, 94.1MB/s]
42%|####2 | 492M/1.13G [00:05<00:07, 93.7MB/s]
44%|####3 | 508M/1.13G [00:05<00:06, 101MB/s]
45%|####5 | 524M/1.13G [00:05<00:05, 115MB/s]
46%|####6 | 535M/1.13G [00:05<00:06, 107MB/s]
47%|####7 | 546M/1.13G [00:05<00:06, 97.5MB/s]
48%|####7 | 557M/1.13G [00:05<00:06, 93.9MB/s]
49%|####9 | 573M/1.13G [00:05<00:05, 112MB/s]
50%|##### | 585M/1.13G [00:05<00:05, 104MB/s]
51%|#####1 | 595M/1.13G [00:06<00:06, 93.6MB/s]
52%|#####2 | 606M/1.13G [00:06<00:06, 86.3MB/s]
54%|#####3 | 623M/1.13G [00:06<00:06, 89.7MB/s]
55%|#####5 | 639M/1.13G [00:06<00:05, 99.6MB/s]
56%|#####6 | 655M/1.13G [00:06<00:05, 100MB/s]
58%|#####7 | 672M/1.13G [00:06<00:05, 99.2MB/s]
59%|#####9 | 688M/1.13G [00:07<00:04, 102MB/s]
60%|###### | 702M/1.13G [00:07<00:04, 110MB/s]
61%|######1 | 713M/1.13G [00:07<00:04, 106MB/s]
62%|######2 | 723M/1.13G [00:07<00:04, 99.5MB/s]
64%|######3 | 737M/1.13G [00:07<00:04, 93.9MB/s]
65%|######4 | 754M/1.13G [00:07<00:04, 95.7MB/s]
66%|######6 | 770M/1.13G [00:07<00:04, 98.2MB/s]
68%|######7 | 786M/1.13G [00:08<00:03, 101MB/s]
69%|######9 | 802M/1.13G [00:08<00:03, 115MB/s]
70%|####### | 814M/1.13G [00:08<00:03, 99.7MB/s]
71%|####### | 824M/1.13G [00:08<00:03, 101MB/s]
72%|#######1 | 836M/1.13G [00:08<00:03, 92.6MB/s]
73%|#######3 | 852M/1.13G [00:08<00:03, 88.4MB/s]
75%|#######4 | 868M/1.13G [00:09<00:03, 99.6MB/s]
76%|#######6 | 885M/1.13G [00:09<00:03, 86.4MB/s]
78%|#######7 | 901M/1.13G [00:09<00:02, 92.5MB/s]
79%|#######9 | 918M/1.13G [00:09<00:02, 108MB/s]
80%|######## | 934M/1.13G [00:09<00:02, 107MB/s]
82%|########1 | 950M/1.13G [00:09<00:01, 120MB/s]
83%|########2 | 962M/1.13G [00:09<00:01, 108MB/s]
84%|########3 | 974M/1.13G [00:10<00:01, 98.4MB/s]
85%|########4 | 984M/1.13G [00:10<00:02, 85.9MB/s]
85%|########5 | 993M/1.13G [00:10<00:02, 88.0MB/s]
87%|########6 | 0.98G/1.13G [00:10<00:01, 100MB/s]
88%|########7 | 0.99G/1.13G [00:10<00:01, 105MB/s]
89%|########8 | 1.01G/1.13G [00:10<00:01, 102MB/s]
90%|######### | 1.02G/1.13G [00:10<00:01, 103MB/s]
92%|#########1| 1.04G/1.13G [00:11<00:01, 96.4MB/s]
93%|#########3| 1.06G/1.13G [00:11<00:00, 98.6MB/s]
95%|#########4| 1.07G/1.13G [00:11<00:00, 103MB/s]
96%|#########5| 1.09G/1.13G [00:11<00:00, 99.1MB/s]
97%|#########7| 1.10G/1.13G [00:11<00:00, 100MB/s]
99%|#########8| 1.12G/1.13G [00:11<00:00, 98.7MB/s]
100%|#########9| 1.13G/1.13G [00:12<00:00, 92.9MB/s]
100%|##########| 1.13G/1.13G [00:12<00:00, 100MB/s]
现在让我们定义我们的典型训练循环。在训练时应该使用真实图像,但为了本教程的目的,我们传入的是假输入,并不担心加载任何实际数据。
IMAGE_SIZE = 224
def train(model, optimizer):
# create our fake image input: tensor shape is batch_size, channels, height, width
fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
# call our forward and backward
loss = model.forward(fake_image)
loss.sum().backward()
# optimizer update
optimizer.step()
optimizer.zero_grad()
训练期间的内存使用情况
我们将要查看一些内存快照,因此我们应该准备好正确地分析它们。通常,训练内存包括:
模型参数(大小 P)
为反向传播保存的激活(大小 A)
梯度,与模型参数大小相同,因此大小 G = P。
优化器状态,其大小与参数的大小成正比。在这种情况下,Adam的状态需要模型参数的2倍,因此大小为O = 2P。
中间张量,在整个计算过程中分配。我们暂时不用担心它们,因为它们通常很小且短暂。
捕获和可视化内存快照
让我们获取一个内存快照!当你的代码运行时,考虑一下你期望的CUDA内存时间线会是什么样子。
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')
# train 3 steps
for _ in range(3):
train(model, optimizer)
# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot.pickle", "wb") as f:
dump(s, f)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
现在通过在CUDA内存可视化器中拖放snapshot.pickle
文件来打开快照,网址为https://pytorch.org/memory_viz。内存时间线是否符合您的预期?

在训练步骤之前,模型参数已经加载到内存中,因此我们立即看到一大块内存被用于权重。当我们开始前向传播时,内存逐渐分配给激活值,或者我们保存的张量以便在反向传播中计算梯度。一旦我们开始反向传播,激活值逐渐被释放,而梯度的内存开始增加。
最后,当优化器启动时,其状态将被延迟初始化,因此我们应该只在第一个训练循环的优化器步骤中看到优化器状态内存逐渐增加。在未来的循环中,优化器内存将保持不变并在原地更新。然后,在每个训练循环结束时,当调用zero_grad
时,梯度的内存将被相应地释放。
在这个训练循环中,内存瓶颈在哪里?或者换句话说,内存峰值出现在哪里?
峰值内存使用发生在优化器步骤期间!注意,此时内存由约1.2GB的参数、约1.2GB的梯度和预期的约2.4GB=2*1.2GB的优化器状态组成。最后的约1.2GB来自Adam优化器需要的内存用于中间变量,总计约6GB的峰值内存。技术上,如果你设置Adam(model.parameters(), foreach=False)
,你可以消除最后1.2GB的优化器中间变量的需求,这将用运行时间来换取内存。如果关闭foreach
运行时优化足以为你节省内存,那很好,但如果你好奇本教程如何帮助你做得更好,请继续阅读!通过我们即将介绍的技术,我们将通过消除约1.2GB的梯度内存以及优化器中间变量内存的需求来减少峰值内存。现在,你预计新的峰值内存会是多少?答案将在下一个快照中揭晓。
免责声明:此技术不适用于所有人
在我们过于兴奋之前,我们必须考虑这种技术是否适用于你的用例。这并不是一个万能解决方案!将优化器步骤融合到反向传播中仅针对减少梯度内存(并且作为副作用也减少了优化器中间状态的内存)。因此,梯度占用的内存越大,内存减少的效果就越显著。在我们上面的例子中,梯度占用了内存饼图的20%,这是相当可观的!
这可能不适用于您的情况,例如,如果您的权重已经非常小(例如,由于应用了LoRa),那么梯度在您的训练循环中占用的空间不大,收益也就不那么令人兴奋。在这种情况下,您应该首先尝试其他技术,如激活检查点、分布式训练、量化或减少批量大小。然后,当梯度再次成为瓶颈时,再回到本教程!
还在这里吗?太好了,让我们介绍一下我们新的 register_post_accumulate_grad_hook(hook)
API 在 Tensor 上的应用。
Tensor.register_post_accumulate_grad_hook(hook)
API 和我们的技术
我们的技术依赖于在backward()
期间不需要保存梯度。相反,一旦梯度被累积,我们将立即将优化器应用于相应的参数,并完全丢弃该梯度!这消除了在优化器步骤之前需要保留大量梯度缓冲区的需求。
那么我们如何才能更积极地应用优化器呢?在我们的2.1版本中,我们添加了一个新的API torch.Tensor.register_post_accumulate_grad_hook()
,它允许我们在张量的.grad
字段被累积后添加一个钩子。我们将把优化器步骤封装到这个钩子中。怎么做呢?
如何在10行代码中整合所有内容
还记得我们一开始的模型和优化器设置吗?我会把它们注释掉,这样我们就不会浪费资源重新运行代码。
model = models.vit_l_16(weights='DEFAULT').cuda()
optimizer = torch.optim.Adam(model.parameters())
# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers
# for every parameter so we could reference them in our hook.
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}
# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``
def optimizer_hook(parameter) -> None:
optimizer_dict[parameter].step()
optimizer_dict[parameter].zero_grad()
# Register the hook onto every parameter
for p in model.parameters():
p.register_post_accumulate_grad_hook(optimizer_hook)
# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
# create our fake image input: tensor shape is batch_size, channels, height, width
fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
# call our forward and backward
loss = model.forward(fake_image)
loss.sum().backward()
# optimizer update --> no longer needed!
# optimizer.step()
# optimizer.zero_grad()
在我们的示例模型中,这大约需要10行代码的更改,这很简洁。 然而,对于实际模型来说,将优化器替换为优化器字典可能是一个相当侵入性的更改,特别是对于那些使用 ``LRScheduler``或在训练周期中操作优化器配置的人来说。与这些更改一起使用这个API将更加复杂,可能需要将更多配置移动到全局状态中,但这并非不可能。也就是说,PyTorch的下一步是使这个API更容易与LRSchedulers和其他你已经习惯的功能一起采用。
但让我回到说服你这种技术是值得的这一点上。我们将咨询我们的朋友,内存快照。
# delete optimizer memory from before to get a clean slate for the next
# memory snapshot
del optimizer
# tell CUDA to start recording memory allocations
torch.cuda.memory._record_memory_history(enabled='all')
# train 3 steps. note that we no longer pass the optimizer into train()
for _ in range(3):
train(model)
# save a snapshot of the memory allocations
s = torch.cuda.memory._snapshot()
with open(f"snapshot-opt-in-bwd.pickle", "wb") as f:
dump(s, f)
# tell CUDA to stop recording memory allocations now
torch.cuda.memory._record_memory_history(enabled=None)
是的,花点时间将你的快照拖到CUDA内存可视化工具中。

- Several major observations:
没有更多的优化器步骤了!没错…我们将其融合到了反向传播中。
同样,向后拖拽的时间更长,并且中间产物有更多的随机分配。这是预期的,因为优化器步骤需要中间产物。
最重要的是!峰值内存更低了!现在大约是4GB(我希望这能接近你之前的预期)。
请注意,与之前相比,不再有为梯度分配的大块内存,节省了约1.2GB的内存。相反,我们通过尽可能提前移动优化器步骤,在计算完每个梯度后迅速释放它们。哇哦!顺便说一下,另外约1.2GB的内存节省来自于将优化器分解为每个参数的优化器,因此中间结果按比例缩小。这个细节比梯度内存节省不那么重要,因为你可以通过仅将foreach=False
而不使用此技术来获得优化器中间结果的节省。
你可能会正确地想知道:如果我们节省了2.4GB的内存,为什么峰值内存不是6GB - 2.4GB = 3.6GB?嗯,峰值已经移动了!峰值现在靠近反向步骤的开始,当我们仍然在内存中有激活时,而之前,峰值是在优化器步骤期间,当激活已经被释放时。因此,~0.4GB的差异(约4.0GB - 约3.6GB)是由于激活内存。然后可以想象,这种技术可以与激活检查点结合使用,以获得更多的内存节省。