自定义编译器

警告

自定义管道功能仅供专家使用。修改编译器行为可能会使numba源代码中的内部假设失效。

对于寻求扩展或修改编译器行为的库开发者,可以通过继承 numba.compiler.CompilerBase 来定义一个自定义编译器。默认的 Numba 编译器定义为 numba.compiler.Compiler,实现了 .define_pipelines() 方法,该方法添加了 nopython-modeobject-modeinterpreted-mode 管道。为了方便,这三个管道在 numba.compiler.DefaultPassBuilder 中通过以下方法定义:

  • .define_nopython_pipeline()

  • .define_objectmode_pipeline()

  • .define_interpreted_pipeline()

分别地。

要使用 CompilerBase 的自定义子类,请将其作为 pipeline_class 关键字参数提供给 @jit 装饰器。通过这样做,自定义管道的效应仅限于被装饰的函数。

实现编译器传递

Numba 使得实现一个新的编译器传递成为可能,并且通过使用类似于 LLVM 的 API 来实现这一点。以下展示了所涉及的基本过程。

编译器传递类

所有通道必须继承自 numba.compiler_machinery.CompilerPass,常用的子类有:

  • numba.compiler_machinery.FunctionPass 用于描述一个在函数级别上操作并可能改变IR状态的传递。

  • numba.compiler_machinery.AnalysisPass 用于描述仅执行分析的传递。

  • numba.compiler_machinery.LoweringPass 用于描述仅执行降低操作的传递。

在这个例子中,将实现一个新的编译器传递,该传递将重写所有 ir.Const(x) 节点,其中 xnumbers.Number 的子类,使得 x 的值增加一。这个传递除了作为教学工具之外,没有其他用途!

numba.compiler_machinery.FunctionPass 适用于建议的传递行为,因此是新传递的基类。此外,定义了一个 run_pass 方法来执行工作(此方法是抽象的,所有编译器传递都必须实现它)。

首先是新类:

from numba import njit
from numba.core import ir
from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.untyped_passes import IRProcessing
from numbers import Number

# Register this pass with the compiler framework, declare that it will not
# mutate the control flow graph and that it is not an analysis_only pass (it
# potentially mutates the IR).
@register_pass(mutates_CFG=False, analysis_only=False)
class ConstsAddOne(FunctionPass):
    _name = "consts_add_one" # the common name for the pass

    def __init__(self):
        FunctionPass.__init__(self)

    # implement method to do the work, "state" is the internal compiler
    # state from the CompilerBase instance.
    def run_pass(self, state):
        func_ir = state.func_ir # get the FunctionIR object
        mutated = False # used to record whether this pass mutates the IR
        # walk the blocks
        for blk in func_ir.blocks.values():
            # find the assignment nodes in the block and walk them
            for assgn in blk.find_insts(ir.Assign):
                # if an assignment value is a ir.Consts
                if isinstance(assgn.value, ir.Const):
                    const_val = assgn.value
                    # if the value of the ir.Const is a Number
                    if isinstance(const_val.value, Number):
                        # then add one!
                        const_val.value += 1
                        mutated |= True
        return mutated # return True if the IR was mutated, False if not.

另请注意,该类必须使用 @register_pass 注册到 Numba 的编译器机制中。这在一定程度上是为了允许声明该过程是否改变了控制流图,以及它是否只是一个分析过程。

接下来,基于现有的 numba.compiler.CompilerBase 定义一个新的编译器。编译器流水线通过使用现有的流水线定义,并将上面声明的新过程添加到 IRProcessing 过程之后运行。

class MyCompiler(CompilerBase): # custom compiler extends from CompilerBase

    def define_pipelines(self):
        # define a new set of pipelines (just one in this case) and for ease
        # base it on an existing pipeline from the DefaultPassBuilder,
        # namely the "nopython" pipeline
        pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
        # Add the new pass to run after IRProcessing
        pm.add_pass_after(ConstsAddOne, IRProcessing)
        # finalize
        pm.finalize()
        # return as an iterable, any number of pipelines may be defined!
        return [pm]

最后,在调用点更新 @njit 装饰器,以利用新定义的编译管道。

@njit(pipeline_class=MyCompiler) # JIT compile using the custom compiler
def foo(x):
    a = 10
    b = 20.2
    c = x + a + b
    return c

print(foo(100)) # 100 + 10 + 20.2 (+ 1 + 1), extra + 1 + 1 from the rewrite!

调试编译器传递

观察IR变化

能够查看一个pass对IR所做的更改通常很有用。Numba通过使用环境变量 NUMBA_DEBUG_PRINT_AFTER 方便地允许这一点。对于上述pass,使用 NUMBA_DEBUG_PRINT_AFTER="ir_processing,consts_add_one" 运行示例代码会得到:

----------------------------nopython: ir_processing-----------------------------
label 0:
    x = arg(0, name=x)                       ['x']
    $const0.1 = const(int, 10)               ['$const0.1']
    a = $const0.1                            ['$const0.1', 'a']
    del $const0.1                            []
    $const0.2 = const(float, 20.2)           ['$const0.2']
    b = $const0.2                            ['$const0.2', 'b']
    del $const0.2                            []
    $0.5 = x + a                             ['$0.5', 'a', 'x']
    del x                                    []
    del a                                    []
    $0.7 = $0.5 + b                          ['$0.5', '$0.7', 'b']
    del b                                    []
    del $0.5                                 []
    c = $0.7                                 ['$0.7', 'c']
    del $0.7                                 []
    $0.9 = cast(value=c)                     ['$0.9', 'c']
    del c                                    []
    return $0.9                              ['$0.9']
----------------------------nopython: consts_add_one----------------------------
label 0:
    x = arg(0, name=x)                       ['x']
    $const0.1 = const(int, 11)               ['$const0.1']
    a = $const0.1                            ['$const0.1', 'a']
    del $const0.1                            []
    $const0.2 = const(float, 21.2)           ['$const0.2']
    b = $const0.2                            ['$const0.2', 'b']
    del $const0.2                            []
    $0.5 = x + a                             ['$0.5', 'a', 'x']
    del x                                    []
    del a                                    []
    $0.7 = $0.5 + b                          ['$0.5', '$0.7', 'b']
    del b                                    []
    del $0.5                                 []
    c = $0.7                                 ['$0.7', 'c']
    del $0.7                                 []
    $0.9 = cast(value=c)                     ['$0.9', 'c']
    del c                                    []
    return $0.9                              ['$0.9']

注意 const 节点中值的变化。

传递执行时间

Numba 内置支持计时所有编译阶段,执行时间存储在与编译结果关联的元数据中。这展示了一种基于先前定义的函数 foo 访问此信息的方法:

compile_result = foo.overloads[foo.signatures[0]]
nopython_times = compile_result.metadata['pipeline_times']['nopython']
for k in nopython_times.keys():
    if ConstsAddOne._name in k:
        print(nopython_times[k])

其输出结果例如:

pass_timings(init=1.914000677061267e-06, run=4.308700044930447e-05, finalize=1.7400006981915794e-06)

这显示了初始化、运行和最终化时间的秒数。