使用Numba重写传递进行优化和乐趣

概述

本节介绍中间表示(IR)重写,以及如何使用它们来实现优化。

如在“阶段 5a:重写类型化 IR”中之前讨论的,重写 Numba IR 使我们能够执行在较低的 LLVM 级别上执行起来要困难得多的优化。类似于 Numba 的类型和降低子系统,重写子系统也是用户可扩展的。这种扩展性赋予了 Numba 支持各种领域特定优化(DSO)的可能性。

剩余的小节详细介绍了实现重写的机制,向重写注册表注册重写的方法,并提供了添加新重写的示例,以及数组表达式优化传递的内部机制。最后,我们将回顾示例中暴露的一些用例,并审查开发人员应注意的任何要点。

重写过程

重写过程具有简单的 match()apply() 接口。匹配和重写的划分遵循在声明性领域特定语言(DSL)中定义术语重写的方式。在这样的DSL中,可以如下编写重写:

<match> => <replacement>

<match><replacement> 符号表示 IR 项表达式,其中左侧表示要匹配的模式,右侧表示基于匹配构建的 IR 项构造器。每当重写匹配到 IR 模式时,左侧中的任何自由变量都会在自定义环境中绑定。应用时,重写使用模式匹配环境来绑定右侧中的任何自由变量。

由于Python通常不用于声明性功能,Numba使用对象状态来处理匹配和应用步骤之间的信息传递。

The Rewrite 基类

class Rewrite

The Rewrite 类简单地定义了Numba重写的抽象基类。开发者应将重写定义为此基类的子类,并重载 match()apply() 方法。

pipeline

pipeline 属性包含当前正在为重写考虑的函数进行编译的 numba.compiler.Pipeline 实例。

__init__(self, pipeline, *args, **kws)

重写的基构造函数只是将其参数存储在同名属性中。除非用于调试或测试,否则重写应仅由 RewriteRegistryRewriteRegistry.apply() 方法中构造,并且构造接口应保持稳定(尽管管道通常会包含几乎所有已知的内容)。

match(self, block, typemap, callmap)

The match() 方法除了 self 之外,还需要四个参数:

  • func_ir: 这是正在重写的函数的 numba.ir.FunctionIR 实例。

  • block: 这是一个 numba.ir.Block 的实例。匹配方法应遍历 numba.ir.Block.body 成员中包含的指令。

  • typemap: 这是一个 Python 字典 实例,从 IR 中的符号名称(以字符串表示)映射到 Numba 类型。

  • callmap: 这是另一个 dict 实例,它从调用映射,表示为 numba.ir.Expr 实例,到它们相应的调用点类型签名,表示为 numba.typing.templates.Signature 实例。

The match() method should return a bool result. A True result should indicate that one or more matches were found, and the apply() method will return a new replacement numba.ir.Block instance. A False result should indicate that no matches were found, and subsequent calls to apply() will return undefined or invalid results.

apply(self)

只有在成功调用 match() 之后,才应调用 apply() 方法。此方法除了 self 之外不接受其他参数,并且应返回一个替换的 numba.ir.Block 实例。

As mentioned above, the behavior of calling apply() is undefined unless match() has already been called and returned True.

子类化 Rewrite

在讨论任何 Rewrite 子类必须重载的方法的期望之前,让我们先退一步,回顾一下这里正在发生的事情。通过提供一个可扩展的编译器,Numba 使自己能够接受用户定义的代码生成器,这些生成器可能是不完整的,或者更糟,是错误的。当一个代码生成器出错时,它可能导致程序行为异常或提前终止。用户定义的重写增加了新的复杂性,因为它们不仅要生成正确的代码,而且它们生成的代码应该确保编译器不会陷入匹配/应用循环中。编译器的非终止将直接导致用户函数调用的非终止。

有几种方法可以帮助确保重写终止:

  • 类型: 重写通常应尝试分解复合类型,并避免组合新类型。如果重写匹配特定类型,将表达式类型更改为较低级别的类型将确保在应用重写后它们不再匹配。

  • 特殊说明:重写可能会合成自定义操作符或在目标IR中使用特殊函数。这种技术再次生成的代码不再属于原始匹配的领域,重写将终止。

在下面的“案例研究:数组表达式”小节中,我们将看到数组表达式重写器如何使用这两种技术。

重载 Rewrite.match()

Every rewrite developer should seek to have their implementation of match() return a False value as quickly as possible. Numba is a just-in-time compiler, and adding compilation time ultimately adds to the user’s run time. When a rewrite returns False for a given block, the registry will no longer process that block with that rewrite, and the compiler is that much closer to proceeding to lowering.

这种对及时性的需求必须与收集必要信息以进行重写匹配的需求相平衡。重写开发者应该能够自如地为他们的子类添加动态属性,然后让这些新属性指导替换基本块的构建。

重载 Rewrite.apply()

The apply() method should return a replacement numba.ir.Block instance to replace the basic block that contained a match for the rewrite. As mentioned above, the IR built by apply() methods should preserve the semantics of the user’s code, but also seek to avoid generating another match for the same rewrite or set of rewrites.

重写注册表

当你想在重写过程中包含一个重写时,你应该将其注册到重写注册表中。numba.rewrites 模块提供了抽象基类和类装饰器,用于钩入 Numba 重写子系统。以下展示了一个新重写的存根定义:

from numba import rewrites

@rewrites.register_rewrite
class MyRewrite(rewrites.Rewrite):

    def match(self, block, typemap, calltypes):
        raise NotImplementedError("FIXME")

    def apply(self):
        raise NotImplementedError("FIXME")

开发者应注意,如上所示使用类装饰器将在导入时注册一个重写。开发者有责任确保他们的扩展在编译开始之前被加载。

案例研究:数组表达式

本小节将更深入地探讨数组表达式重写器。数组表达式重写器及其大部分支持功能位于 numba.npyufunc.array_exprs 模块中。重写过程本身在 RewriteArrayExprs 类中实现。除了重写器之外,array_exprs 模块还包括一个用于降低数组表达式的函数,即 _lower_array_expr()。整体优化过程如下:

  • RewriteArrayExprs.match(): 重写过程寻找一个或多个形成数组表达式的数组操作。

  • RewriteArrayExprs.apply(): 一旦找到数组表达式,重写器会将各个数组操作替换为一种新的IR表达式,即 arrayexpr

  • numba.npyufunc.array_exprs._lower_array_expr():在降低过程中,代码生成器在发现 arrayexpr IR 表达式时调用 _lower_array_expr()

优化过程的每一步详细信息如下。

The RewriteArrayExprs.match() 方法

数组表达式优化过程首先寻找数组操作,包括对支持的 ufunc 和用户定义的 DUFunc 的调用。Numba IR遵循静态单赋值(SSA)语言的约定,这意味着寻找数组操作符的过程从查找赋值指令开始。

当重写过程调用 RewriteArrayExprs.match() 方法时,它首先检查是否可以简单地拒绝该基本块。如果该方法确定该块是匹配的候选者,它会在重写对象中设置以下状态变量:

  • crnt_block: 当前正在匹配的基本块。

  • typemap: 正在匹配的函数的 typemap

  • matches: 一个引用数组表达式的变量名称列表。

  • array_assigns: 一个从赋值变量名到实际赋值指令的映射,这些指令定义了给定的变量。

  • const_assigns: 一个从赋值变量名到定义常量变量的常量值表达式的映射。

此时,匹配方法会遍历输入基本块中的赋值指令。对于每个赋值指令,匹配器会寻找以下两种情况之一:

  • 数组操作:如果赋值指令的右侧是一个表达式,并且该表达式的结果是数组类型,匹配器会检查该表达式是否是已知的数组操作,或者是调用通用函数。如果找到数组操作符,匹配器会将左侧的变量名和整个指令存储在 array_assigns 成员中。最后,匹配器会测试数组操作的任何操作数是否也被识别为其他数组操作的目标。如果一个或多个操作数也是数组操作的目标,那么匹配器也会将左侧的变量名附加到 matches 成员中。

  • 常量:常量(即使是标量)也可以作为数组操作的操作数。在不考虑常量是否为数组表达式的一部分的情况下,匹配器将常量名称和值存储在 const_assigns 成员中。

The end of the matching method simply checks for a non-empty matches list, returning True if there were one or more matches, and False when matches is empty.

The RewriteArrayExprs.apply() 方法

当通过 RewriteArrayExprs.match() 找到一个或匹配的数组表达式时,重写过程将调用 RewriteArrayExprs.apply()。apply 方法分两步进行。第一步遍历找到的匹配项,并在新基本块中从旧基本块的指令构建到新指令的映射。第二步遍历旧基本块中的指令,复制未被重写更改的指令,并替换或删除第一步中识别的指令。

The RewriteArrayExprs._handle_matches() 实现了重写代码生成部分的第一遍。对于每个匹配项,此方法构建一个包含数组表达式的表达式树的特殊IR表达式。为了计算表达式树的叶子,_handle_matches() 方法遍历已识别的根操作的操作数。如果操作数是另一个数组操作,则将其转换为表达式子树。如果操作数是常量,_handle_matches() 复制常量值。否则,操作数被标记为由数组表达式使用。当方法构建数组表达式节点时,它会构建从旧指令到新指令的映射(replace_map),以及可能已移动的变量集(used_vars),以及应完全删除的变量集(dead_vars)。这三个数据结构被返回给调用的 RewriteArrayExprs.apply() 方法。

剩余的 RewriteArrayExprs.apply() 方法遍历旧基本块中的指令。对于每条指令,该方法根据 RewriteArrayExprs._handle_matches() 的结果替换、删除或复制该指令。以下列表描述了优化如何处理单个指令:

  • When an instruction is an assignment, apply() checks to see if it is in the replacement instruction map. When an assignment instruction is found in the instruction map, apply() must then check to see if the replacement instruction is also in the replacement map. The optimizer continues this check until it either arrives at a None value or an instruction that isn’t in the replacement map. Instructions that have a replacement that is None are deleted. Instructions that have a non-None replacement are replaced. Assignment instructions not in the replacement map are appended to the new basic block with no changes made.

  • 当指令是删除指令时,重写检查它是否删除了一个可能仍被后续数组表达式使用的变量,或者是否删除了一个死变量。对于使用的变量的删除指令被添加到一个延迟删除指令的映射中,该映射由 apply() 使用,以将它们移动到该变量的任何使用之后。循环复制非死变量的删除指令,并忽略死变量的删除指令(有效地将其从基本块中移除)。

  • 所有其他指令都被附加到新的基本块中。

最后,apply() 方法返回用于降低的新基本块。

_lower_array_expr() 函数

如果我们仅仅停留在重写阶段,那么编译器的降低阶段将会失败,抱怨它不知道如何降低 arrayexpr 操作。我们首先在每次编译器实例化 RewriteArrayExprs 类时,将一个降低函数挂钩到目标上下文中。这个挂钩导致降低过程在遇到 arrayexpr 操作时调用 _lower_array_expr()

此函数有两个步骤:

  • 合成一个实现数组表达式的Python函数:这个新的Python函数本质上表现得像一个Numpy ufunc,返回在广播数组参数中标量值的表达式结果。降低函数通过将数组表达式树转换为Python AST来实现这一点。

  • 将合成 Python 函数编译为内核:此时,降低函数依赖于现有的降低 ufunc 和 DUFunc 内核的代码,在定义了如何降低对合成函数的调用后,调用 numba.targets.numpyimpl.numpy_ufunc_kernel()

最终结果类似于 Numba 对象模式中的循环提升。

结论与注意事项

我们已经了解了如何在 Numba 中实现重写,从接口开始,到实际的优化结束。本节的关键点是:

  • 在编写一个好的插件时,匹配器应尽量尽快获得一个通过/不通过的结果。

  • 重写应用程序部分可能会更加计算密集,但仍应生成不会导致编译器中出现无限循环的代码。

  • 我们使用对象状态来向重写应用程序传递匹配的任何结果。