• Tutorials >
  • Compiled Autograd: Capturing a larger backward graph for torch.compile
Shortcuts

编译自动求导:为torch.compile捕获更大的反向图

创建于:2024年10月09日 | 最后更新:2024年10月23日 | 最后验证:2024年10月09日

作者: Simon Fan

What you will learn
  • 编译的自动梯度如何与torch.compile交互

  • 如何使用编译的自动梯度API

  • 如何使用TORCH_LOGS检查日志

Prerequisites

概述

编译自动梯度是PyTorch 2.4中引入的一个torch.compile扩展,它允许捕获更大的反向图。

虽然torch.compile确实捕获了反向图,但它只是部分捕获。AOTAutograd组件提前捕获反向图,但存在一些限制:

  • 前向传播中的图断裂会导致后向传播中的图断裂

  • Backward hooks 未被捕获

编译自动梯度通过直接与自动梯度引擎集成来解决这些限制,使其能够在运行时捕获完整的反向图。具有这两个特征的模型应该尝试编译自动梯度,并可能观察到更好的性能。

然而,Compiled Autograd 引入了自身的限制:

  • 在反向传播开始时增加了缓存查找的运行时开销

  • 由于捕获范围较大,在dynamo中更容易发生重新编译和图中断

注意

编译的自动梯度功能正在积极开发中,尚未与所有现有的PyTorch功能兼容。有关特定功能的最新状态,请参阅编译的自动梯度登陆页面

设置

在本教程中,我们将基于这个简单的神经网络模型进行示例。 它接收一个10维的输入向量,通过一个单一的线性层进行处理,并输出另一个10维的向量。

import torch

class Model(torch.nn.Module):
   def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(10, 10)

   def forward(self, x):
      return self.linear(x)

基本用法

在调用torch.compile API之前,请确保将torch._dynamo.config.compiled_autograd设置为True

model = Model()
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True
@torch.compile
def train(model, x):
   loss = model(x).sum()
   loss.backward()

train(model, x)

在上面的代码中,我们创建了一个Model类的实例,并通过使用torch.randn(10)生成了一个随机的10维张量x。 我们定义了训练循环函数train,并用@torch.compile装饰它以优化其执行。 当调用train(model, x)时:

  • Python 解释器调用 Dynamo,因为此调用被装饰为 @torch.compile

  • Dynamo拦截Python字节码,模拟它们的执行并将操作记录到图中。

  • AOTDispatcher 禁用钩子并调用自动求导引擎来计算 model.linear.weightmodel.linear.bias 的梯度,并将操作记录到图中。使用 torch.autograd.Function,AOTDispatcher 重写了 train 的前向和后向实现。

  • Inductor 生成一个函数,对应于 AOTDispatcher 前向和后向的优化实现。

  • Dynamo 设置将由 Python 解释器评估的优化函数。

  • Python 解释器执行优化后的函数,该函数执行 loss = model(x).sum()

  • Python 解释器执行 loss.backward(),调用自动求导引擎,由于我们设置了 torch._dynamo.config.compiled_autograd = True,因此路由到编译后的自动求导引擎。

  • 编译后的自动梯度计算为model.linear.weightmodel.linear.bias计算梯度,并将操作记录到图中,包括它遇到的任何钩子。在此过程中,它将记录之前由AOTDispatcher重写的反向传播。编译后的自动梯度然后生成一个新函数,该函数对应于loss.backward()的完全跟踪实现,并在推理模式下使用torch.compile执行它。

  • 相同的步骤递归地应用于编译后的自动梯度图,但这次AOTDispatcher将不需要对图进行分区。

检查编译的自动梯度日志

使用TORCH_LOGS环境变量运行脚本:

  • 要仅打印编译的自动梯度图,请使用 TORCH_LOGS="compiled_autograd" python example.py

  • 为了以性能为代价打印包含更多张量元数据和重新编译原因的图表,请使用 TORCH_LOGS="compiled_autograd_verbose" python example.py

重新运行上面的代码片段,编译后的自动梯度图现在应该会记录到stderr。某些图节点会有以aot0_为前缀的名称,这些节点对应于之前在AOTAutograd反向图0中提前编译的节点,例如,aot0_view_2对应于id=0的AOT反向图中的view_2

在下图中,红色框封装了由torch.compile捕获的AOT反向图,而没有使用Compiled Autograd。

../_images/entire_verbose_log.png

注意

这是我们将调用torch.compile的图,不是优化后的图。编译后的Autograd本质上生成了一些未优化的Python代码来表示整个C++ autograd执行。

使用不同的标志编译前向和后向传递

您可以为两次编译使用不同的编译器配置,例如,即使在前向中存在图中断,后向也可能是一个完整图。

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

或者你可以使用上下文管理器,它将适用于其范围内的所有自动求导调用。

def train(model, x):
   model = torch.compile(model)
   loss = model(x).sum()
   with torch._dynamo.compiled_autograd.enable(torch.compile(fullgraph=True)):
      loss.backward()

编译的自动梯度解决了AOTAutograd的某些限制

  1. 在前向传递中的图中断不再必然导致后向传递中的图中断:

@torch.compile(backend="aot_eager")
def fn(x):
   # 1st graph
   temp = x + 10
   torch._dynamo.graph_break()
   # 2nd graph
   temp = temp + 10
   torch._dynamo.graph_break()
   # 3rd graph
   return temp.sum()

x = torch.randn(10, 10, requires_grad=True)
torch._dynamo.utils.counters.clear()
loss = fn(x)

# 1. base torch.compile
loss.backward(retain_graph=True)
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 3)
torch._dynamo.utils.counters.clear()

# 2. torch.compile with compiled autograd
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

# single graph for the backward
assert(torch._dynamo.utils.counters["stats"]["unique_graphs"] == 1)

在第一个torch.compile案例中,我们看到由于编译函数fn中的2个图中断,生成了3个反向图。 而在第二个torch.compile与编译自动求导的案例中,我们看到尽管有图中断,仍然追踪了一个完整的反向图。

注意

当追踪由Compiled Autograd捕获的反向钩子时,Dynamo仍然可能出现图形中断的情况。

  1. 现在可以捕获反向钩子了

@torch.compile(backend="aot_eager")
def fn(x):
   return x.sum()

x = torch.randn(10, 10, requires_grad=True)
x.register_hook(lambda grad: grad+10)
loss = fn(x)

with torch._dynamo.compiled_autograd.enable(torch.compile(backend="aot_eager")):
   loss.backward()

图中应该有一个call_hook节点,dynamo稍后会将其内联为以下内容:

../_images/call_hook_node.png

Compiled Autograd 的常见重新编译原因

  1. 由于损失值的自动梯度结构发生了变化:

torch._dynamo.config.compiled_autograd = True
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
   loss = op(x, x).sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的例子中,我们在每次迭代时调用不同的操作符,导致loss每次跟踪不同的自动梯度历史。你应该会看到一些重新编译的消息:由于新的自动梯度节点导致缓存未命中

../_images/recompile_due_to_node.png
  1. 由于张量形状的变化:

torch._dynamo.config.compiled_autograd = True
for i in [10, 100, 10]:
   x = torch.randn(i, i, requires_grad=True)
   loss = x.sum()
   torch.compile(lambda: loss.backward(), backend="eager")()

在上面的例子中,x 改变了形状,编译后的自动梯度将在第一次改变后将 x 标记为动态形状张量。你应该会看到重新编译的消息:由于形状改变导致的缓存未命中

../_images/recompile_due_to_dynamic.png

结论

在本教程中,我们概述了带有编译自动微分的torch.compile的高级生态系统,编译自动微分的基础知识以及一些常见的重新编译原因。请继续关注dev-discuss上的深入探讨。