Notes on Inlining
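There are occasions where it is useful to be able to inline a function at its call site, at the level of the Numba IR representation. Decorators such as numba.jit(), numba.extending.overload() and register_jitable() support the keyword argument inline to facilitate this behaviour.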
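When attempting to inline at this level, it is important to understand what purpose it serves and what effect it will have. Unlike the inlining performed by LLVM, which is aimed at improving performance, the main reason to inline at the Numba IR level is to allow type inference to cross function boundaries.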
As an example, consider the following snippet:
from numba import njit

@njit
def bar(a):
    a.append(10)

@njit
def foo():
    z = []
    bar(z)

foo()
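This will fail to compile and run, because the type of z cannot be inferred: it would only be refined inside bar. If we now add inline='always' to the decorator for bar, the snippet compiles and runs. This is because inlining the call to a.append(10) means that z is refined to hold integers, so type inference succeeds.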
To summarise, inlining at the Numba IR level is unlikely to bring a performance benefit, whereas inlining at the LLVM level is more likely to.
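The inline keyword argument can be one of three values:
- The string 'never', which is the default and results in the function not being inlined under any circumstances.
- The string 'always', which results in the function being inlined at all call sites.
- A Python function that takes three arguments. The first argument is always the ir.Expr node of the call requesting the inline. It is present so that the function can make decisions based on the calling context. The second and third arguments are:
  - In the case of an untyped inline, i.e. the one that occurs when using the numba.jit() family of decorators, both arguments are numba.ir.FunctionIR instances. The second argument corresponds to the IR of the caller and the third to the IR of the callee.
  - In the case of a typed inline, i.e. the one that occurs when using numba.extending.overload(), both arguments are instances of a namedtuple with the following fields (corresponding to their standard use in the compiler internals):
    - func_ir: the function's Numba IR.
    - typemap: the function's type map.
    - calltypes: the call types of any calls in the function.
    - signature: the function's signature.
    The second argument holds the information for the caller and the third holds the information for the callee.
  In all cases the function should return True to inline and False to not inline. This permits custom inlining rules; a typical use is a cost model.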
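A cost model is an ordinary Python callable, so a single function can serve both the untyped and the typed case if it normalises its arguments first. The sketch below is not part of Numba. The factory name make_size_cost_model and its threshold are hypothetical. It assumes, as the numba.jit() example in this section does, that a FunctionIR exposes a blocks mapping whose values carry a body list of statements. It also assumes that the typed namedtuple carries the IR in its func_ir field.

```python
def make_size_cost_model(max_stmts):
    """Hypothetical cost-model factory (not part of Numba).

    The returned callable approves inlining only when the callee's IR holds
    at most ``max_stmts`` statements across all of its blocks.
    """

    def _function_ir(info):
        # Typed inlining (overload) passes a namedtuple with a ``func_ir``
        # field; untyped inlining (jit) passes the FunctionIR directly.
        return getattr(info, "func_ir", info)

    def cost_model(expr, caller_info, callee_info):
        callee_ir = _function_ir(callee_info)
        n_stmts = sum(len(blk.body) for blk in callee_ir.blocks.values())
        return n_stmts <= max_stmts

    return cost_model
```

It would be attached with something like @njit(inline=make_size_cost_model(10)). The threshold of 10 is arbitrary. It counts raw IR statements, which are more numerous than the lines of source.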
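Recursive functions that use inline='always' will result in a non-terminating compilation. If you wish to avoid this, supply a function that limits the recursion depth (see below).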
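Note
There is no guarantee about the order in which functions are assessed for inlining, or about the order in which they are inlined.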
Example using numba.jit()
An example of using all three options for inline in the numba.njit() decorator:
from numba import njit
import numba
from numba.core import ir

@njit(inline='never')
def never_inline():
    return 100

@njit(inline='always')
def always_inline():
    return 200

def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return before_expr
    return False

@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300

@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37

@njit
def foo():
    a = never_inline() # will never inline
    b = always_inline() # will always inline

    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const

foo()
When executed, this produces the following (printed by setting the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization", which prints the IR after legalization):
label 0:
$0.1 = global(never_inline: CPUDispatcher(<function never_inline at 0x7f890ccf9048>)) ['$0.1']
$0.2 = call $0.1(func=$0.1, args=[], kws=(), vararg=None) ['$0.1', '$0.2']
del $0.1 []
a = $0.2 ['$0.2', 'a']
del $0.2 []
$0.3 = global(always_inline: CPUDispatcher(<function always_inline at 0x7f890ccf9598>)) ['$0.3']
del $0.3 []
$const0.1.0 = const(int, 200) ['$const0.1.0']
$0.2.1 = $const0.1.0 ['$0.2.1', '$const0.1.0']
del $const0.1.0 []
$0.4 = $0.2.1 ['$0.2.1', '$0.4']
del $0.2.1 []
b = $0.4 ['$0.4', 'b']
del $0.4 []
$0.5 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.5']
$0.6 = call $0.5(func=$0.5, args=[], kws=(), vararg=None) ['$0.5', '$0.6']
del $0.5 []
d = $0.6 ['$0.6', 'd']
del $0.6 []
$const0.7 = const(int, 13) ['$const0.7']
magic_const = $const0.7 ['$const0.7', 'magic_const']
del $const0.7 []
$0.8 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.8']
del $0.8 []
$const0.1.2 = const(int, 300) ['$const0.1.2']
$0.2.3 = $const0.1.2 ['$0.2.3', '$const0.1.2']
del $const0.1.2 []
$0.9 = $0.2.3 ['$0.2.3', '$0.9']
del $0.2.3 []
e = $0.9 ['$0.9', 'e']
del $0.9 []
$0.10 = global(maybe_inline2: CPUDispatcher(<function maybe_inline2 at 0x7f890ccf9b70>)) ['$0.10']
del $0.10 []
$const0.1.4 = const(int, 37) ['$const0.1.4']
$0.2.5 = $const0.1.4 ['$0.2.5', '$const0.1.4']
del $const0.1.4 []
$0.11 = $0.2.5 ['$0.11', '$0.2.5']
del $0.2.5 []
c = $0.11 ['$0.11', 'c']
del $0.11 []
$0.14 = a + b ['$0.14', 'a', 'b']
del b []
del a []
$0.16 = $0.14 + c ['$0.14', '$0.16', 'c']
del c []
del $0.14 []
$0.18 = $0.16 + d ['$0.16', '$0.18', 'd']
del d []
del $0.16 []
$0.20 = $0.18 + e ['$0.18', '$0.20', 'e']
del e []
del $0.18 []
$0.22 = $0.20 + magic_const ['$0.20', '$0.22', 'magic_const']
del magic_const []
del $0.20 []
$0.23 = cast(value=$0.22) ['$0.22', '$0.23']
del $0.22 []
return $0.23 ['$0.23']
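Things to note in the above:
- The call to the function never_inline remains as a call.
- The always_inline function has been inlined; note its const(int, 200) in the caller body.
- There is a call to maybe_inline1 before the const(int, 13) declaration; the cost model prevented this one from being inlined.
- After the const(int, 13), the subsequent call to maybe_inline1 has been inlined, as shown by the const(int, 300) in the caller body.
- The function maybe_inline2 has been inlined, as demonstrated by the const(int, 37) in the caller body.
- Dead code elimination has not been performed, so superfluous statements are present in the IR.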
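Checking these dumps by eye gets tedious for larger functions. As a rough aid, the hypothetical helper below is not part of Numba. It scans a dump in the format shown above and reports which globals are still invoked through a call expression. It assumes the `$var = global(name: ...)` and `$var = call $var(...)` line shapes seen in this output, which may differ between Numba versions.

```python
import re

# Matches e.g. "$0.1 = global(never_inline: CPUDispatcher(...)) ['$0.1']"
_GLOBAL_RE = re.compile(r"^\s*(\$[\w.]+) = global\((\w+):")
# Matches e.g. "$0.2 = call $0.1(func=$0.1, args=[], ...) ['$0.1', '$0.2']"
_CALL_RE = re.compile(r"^\s*[\w.$]+ = call (\$[\w.]+)\(")


def surviving_calls(ir_dump):
    """Return the names of globals still called in an IR dump, in order."""
    globals_by_var = {}
    called = []
    for line in ir_dump.splitlines():
        m = _GLOBAL_RE.match(line)
        if m:
            # Remember which IR variable holds which global (e.g. a dispatcher).
            globals_by_var[m.group(1)] = m.group(2)
            continue
        m = _CALL_RE.match(line)
        if m:
            # Fall back to the raw variable name if its global was not seen.
            called.append(globals_by_var.get(m.group(1), m.group(1)))
    return called
```

Applied to the full dump above, it would report never_inline and maybe_inline1, which are the two calls that were not inlined.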
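Example using numba.extending.overload()
An example of using inlining with the numba.extending.overload() decorator. Note that if a function is supplied as the argument to inline, much more information is available through that function's arguments for use in decision making. Also, different @overload implementations can have different inlining behaviours, and there are multiple ways to achieve this: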
import numba
from numba.extending import overload
from numba import njit, types

def bar(x):
    """A function stub to overload"""
    pass

@overload(bar, inline='always')
def ol_bar_tuple(x):
    # An overload that will always inline, there is a type guard so that this
    # only applies to UniTuples.
    if isinstance(x, types.UniTuple):
        def impl(x):
            return x[0]
        return impl

def cost_model(expr, caller, callee):
    # Only inline if the type of the argument is an Integer
    return isinstance(caller.typemap[expr.args[0].name], types.Integer)

@overload(bar, inline=cost_model)
def ol_bar_scalar(x):
    # An overload that will inline based on a cost model, it only applies to
    # scalar values in the numerical domain as per the type guard on Number
    if isinstance(x, types.Number):
        def impl(x):
            return x + 1
        return impl

@njit
def foo():
    # This will resolve via `ol_bar_tuple` as the argument is a types.UniTuple
    # instance. It will always be inlined as specified in the decorator for this
    # overload.
    a = bar((1, 2, 3))

    # This will resolve via `ol_bar_scalar` as the argument is a types.Number
    # instance, hence the cost_model will be used to determine whether to
    # inline.
    # The function will be inlined as the value 100 is an IntegerLiteral which
    # is an instance of a types.Integer as required by the cost_model function.
    b = bar(100)

    # This will also resolve via `ol_bar_scalar` as the argument is a
    # types.Number instance, again the cost_model will be used to determine
    # whether to inline.
    # The function will not be inlined as the complex value is not an instance
    # of a types.Integer as required by the cost_model function.
    c = bar(300j)

    return a + b + c

foo()
When executed, this produces the following (printed by setting the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization", which prints the IR after legalization):
label 0:
$const0.2 = const(tuple, (1, 2, 3)) ['$const0.2']
x.0 = $const0.2 ['$const0.2', 'x.0']
del $const0.2 []
$const0.2.2 = const(int, 0) ['$const0.2.2']
$0.3.3 = getitem(value=x.0, index=$const0.2.2) ['$0.3.3', '$const0.2.2', 'x.0']
del x.0 []
del $const0.2.2 []
$0.4.4 = $0.3.3 ['$0.3.3', '$0.4.4']
del $0.3.3 []
$0.3 = $0.4.4 ['$0.3', '$0.4.4']
del $0.4.4 []
a = $0.3 ['$0.3', 'a']
del $0.3 []
$const0.5 = const(int, 100) ['$const0.5']
x.5 = $const0.5 ['$const0.5', 'x.5']
del $const0.5 []
$const0.2.7 = const(int, 1) ['$const0.2.7']
$0.3.8 = x.5 + $const0.2.7 ['$0.3.8', '$const0.2.7', 'x.5']
del x.5 []
del $const0.2.7 []
$0.4.9 = $0.3.8 ['$0.3.8', '$0.4.9']
del $0.3.8 []
$0.6 = $0.4.9 ['$0.4.9', '$0.6']
del $0.4.9 []
b = $0.6 ['$0.6', 'b']
del $0.6 []
$0.7 = global(bar: <function bar at 0x7f6c3710d268>) ['$0.7']
$const0.8 = const(complex, 300j) ['$const0.8']
$0.9 = call $0.7($const0.8, func=$0.7, args=[Var($const0.8, inline_overload_example.py (56))], kws=(), vararg=None) ['$0.7', '$0.9', '$const0.8']
del $const0.8 []
del $0.7 []
c = $0.9 ['$0.9', 'c']
del $0.9 []
$0.12 = a + b ['$0.12', 'a', 'b']
del b []
del a []
$0.14 = $0.12 + c ['$0.12', '$0.14', 'c']
del c []
del $0.12 []
$0.15 = cast(value=$0.14) ['$0.14', '$0.15']
del $0.14 []
return $0.15 ['$0.15']
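Things to note in the above:
- The statements that compute a (the const tuple (1, 2, 3) followed by a getitem on x.0) are the always-inlined overload for the UniTuple argument type.
- The statements that compute b (x.5 + $const0.2.7) are the overload for the Number argument type. It was inlined because the cost model function accepted it: the argument is an instance of an Integer type.
- The statements that compute c are the overload for the Number argument type that was not inlined. The cost model function rejected it because the argument is an instance of a Complex type, so a call to bar remains.
- Dead code elimination has not been performed, so superfluous statements are present in the IR.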
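The cost model in this example inspects the caller's typemap. The typed namedtuple also carries a signature field, so a cost model can instead decide from the callee's resolved argument types. The following is a minimal sketch, not part of Numba. The factory name inline_if_args_are is hypothetical. It assumes the signature object exposes its argument types as .args.

```python
def inline_if_args_are(*accepted_types):
    """Hypothetical factory for a typed-inline cost model.

    The returned callable inlines only when every argument type in the
    callee's signature is an instance of one of ``accepted_types``.
    """

    def cost_model(expr, caller, callee):
        # ``callee.signature.args`` holds the argument types the overload
        # implementation was typed with; the call expression is not needed.
        return all(isinstance(t, accepted_types) for t in callee.signature.args)

    return cost_model
```

With Numba, this could be attached as @overload(bar, inline=inline_if_args_are(types.Integer)). That expresses a rule similar to cost_model above, but reads it from the callee side.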
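Using a function to limit the inlining depth of a recursive function
When using recursive inlines, you can ensure that compilation terminates by using a cost model.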
from numba import njit

class CostModel:
    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        ret = self._count < self._max_inlines
        self._count += 1
        return ret

@njit(inline=CostModel(3))
def factorial(n):
    if n <= 0:
        return 1
    return n * factorial(n - 1)

factorial(5)
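CostModel is a plain Python callable, so its decisions can be checked without compiling anything. The snippet below repeats the class from the example above and calls it directly with placeholder arguments. The None values stand in for the IR objects Numba would pass, which this model ignores. The first max_inlines requests are approved and every later one is refused. This refusal is what bounds the inlining of factorial.

```python
class CostModel:
    """Approves the first ``max_inlines`` inlining requests, then refuses."""

    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        ret = self._count < self._max_inlines
        self._count += 1
        return ret


# The arguments are ignored by this model, so None stands in for the IR objects.
cm = CostModel(3)
decisions = [cm(None, None, None) for _ in range(5)]
print(decisions)  # [True, True, True, False, False]
```

The counter lives on the instance, so it is shared by every inlining decision that consults that instance, not tracked per call site.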