自定义编译器
警告
自定义管道功能仅供专家使用。修改编译器行为可能会使numba源代码中的内部假设失效。
对于寻求扩展或修改编译器行为的库开发者,可以通过继承 numba.compiler.CompilerBase
来定义一个自定义编译器。默认的 Numba 编译器定义为 numba.compiler.Compiler
,实现了 .define_pipelines()
方法,该方法添加了 nopython-mode、object-mode 和 interpreted-mode 管道。为了方便,这三个管道在 numba.compiler.DefaultPassBuilder
中通过以下方法定义:
.define_nopython_pipeline()
.define_objectmode_pipeline()
.define_interpreted_pipeline()
分别地。
要使用 CompilerBase
的自定义子类,请将其作为 pipeline_class
关键字参数提供给 @jit
装饰器。通过这样做,自定义管道的效应仅限于被装饰的函数。
实现编译器传递
Numba 使得实现一个新的编译器传递成为可能,并且通过使用类似于 LLVM 的 API 来实现这一点。以下展示了所涉及的基本过程。
编译器传递类
所有通道必须继承自 numba.compiler_machinery.CompilerPass
,常用的子类有:
numba.compiler_machinery.FunctionPass
用于描述一个在函数级别上操作并可能改变IR状态的传递。numba.compiler_machinery.AnalysisPass
用于描述仅执行分析的传递。numba.compiler_machinery.LoweringPass
用于描述仅执行降低操作的传递。
在这个例子中,将实现一个新的编译器传递,该传递将重写所有 ir.Const(x)
节点,其中 x
是 numbers.Number
的子类,使得 x 的值增加一。这个传递除了作为教学工具之外,没有其他用途!
numba.compiler_machinery.FunctionPass
适用于建议的传递行为,因此是新传递的基类。此外,定义了一个 run_pass
方法来执行工作(此方法是抽象的,所有编译器传递都必须实现它)。
首先是新类:
from numba import njit
from numba.core import ir
from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.compiler_machinery import FunctionPass, register_pass
from numba.core.untyped_passes import IRProcessing
from numbers import Number
# Register this pass with the compiler framework, declare that it will not
# mutate the control flow graph and that it is not an analysis_only pass (it
# potentially mutates the IR).
@register_pass(mutates_CFG=False, analysis_only=False)
class ConstsAddOne(FunctionPass):
_name = "consts_add_one" # the common name for the pass
def __init__(self):
FunctionPass.__init__(self)
# implement method to do the work, "state" is the internal compiler
# state from the CompilerBase instance.
def run_pass(self, state):
func_ir = state.func_ir # get the FunctionIR object
mutated = False # used to record whether this pass mutates the IR
# walk the blocks
for blk in func_ir.blocks.values():
# find the assignment nodes in the block and walk them
for assgn in blk.find_insts(ir.Assign):
# if an assignment value is a ir.Consts
if isinstance(assgn.value, ir.Const):
const_val = assgn.value
# if the value of the ir.Const is a Number
if isinstance(const_val.value, Number):
# then add one!
const_val.value += 1
mutated |= True
return mutated # return True if the IR was mutated, False if not.
另请注意,该类必须使用 @register_pass
注册到 Numba 的编译器机制中。这在一定程度上是为了允许声明该过程是否改变了控制流图,以及它是否只是一个分析过程。
接下来,基于现有的 numba.compiler.CompilerBase
定义一个新的编译器。编译器流水线通过使用现有的流水线定义,并将上面声明的新过程添加到 IRProcessing
过程之后运行。
class MyCompiler(CompilerBase): # custom compiler extends from CompilerBase
def define_pipelines(self):
# define a new set of pipelines (just one in this case) and for ease
# base it on an existing pipeline from the DefaultPassBuilder,
# namely the "nopython" pipeline
pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
# Add the new pass to run after IRProcessing
pm.add_pass_after(ConstsAddOne, IRProcessing)
# finalize
pm.finalize()
# return as an iterable, any number of pipelines may be defined!
return [pm]
最后,在调用点更新 @njit
装饰器,以利用新定义的编译管道。
@njit(pipeline_class=MyCompiler) # JIT compile using the custom compiler
def foo(x):
a = 10
b = 20.2
c = x + a + b
return c
print(foo(100)) # 100 + 10 + 20.2 (+ 1 + 1), extra + 1 + 1 from the rewrite!
调试编译器传递
观察IR变化
能够查看一个pass对IR所做的更改通常很有用。Numba通过使用环境变量 NUMBA_DEBUG_PRINT_AFTER
方便地允许这一点。对于上述pass,使用 NUMBA_DEBUG_PRINT_AFTER="ir_processing,consts_add_one"
运行示例代码会得到:
----------------------------nopython: ir_processing-----------------------------
label 0:
x = arg(0, name=x) ['x']
$const0.1 = const(int, 10) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
del $const0.1 []
$const0.2 = const(float, 20.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']
del $const0.2 []
$0.5 = x + a ['$0.5', 'a', 'x']
del x []
del a []
$0.7 = $0.5 + b ['$0.5', '$0.7', 'b']
del b []
del $0.5 []
c = $0.7 ['$0.7', 'c']
del $0.7 []
$0.9 = cast(value=c) ['$0.9', 'c']
del c []
return $0.9 ['$0.9']
----------------------------nopython: consts_add_one----------------------------
label 0:
x = arg(0, name=x) ['x']
$const0.1 = const(int, 11) ['$const0.1']
a = $const0.1 ['$const0.1', 'a']
del $const0.1 []
$const0.2 = const(float, 21.2) ['$const0.2']
b = $const0.2 ['$const0.2', 'b']
del $const0.2 []
$0.5 = x + a ['$0.5', 'a', 'x']
del x []
del a []
$0.7 = $0.5 + b ['$0.5', '$0.7', 'b']
del b []
del $0.5 []
c = $0.7 ['$0.7', 'c']
del $0.7 []
$0.9 = cast(value=c) ['$0.9', 'c']
del c []
return $0.9 ['$0.9']
注意 const
节点中值的变化。
传递执行时间
Numba 内置支持计时所有编译阶段,执行时间存储在与编译结果关联的元数据中。这展示了一种基于先前定义的函数 foo
访问此信息的方法:
compile_result = foo.overloads[foo.signatures[0]]
nopython_times = compile_result.metadata['pipeline_times']['nopython']
for k in nopython_times.keys():
if ConstsAddOne._name in k:
print(nopython_times[k])
其输出结果例如:
pass_timings(init=1.914000677061267e-06, run=4.308700044930447e-05, finalize=1.7400006981915794e-06)
这显示了初始化、运行和最终化时间的秒数。