Notes on Inlining

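There are occasions where it is useful to be able to inline a function at its call site, at the Numba IR level of representation. The decorators numba.jit(), numba.extending.overload() and register_jitable() support the keyword argument inline to facilitate this behaviour.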

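When attempting to inline at this level, it is important to understand what purpose this serves and what effect it will have. In contrast to the inlining performed by LLVM, which is aimed at improving performance, the main reason to inline at the Numba IR level is to allow type inference to cross function boundaries.

As an example, consider the following snippet: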

from numba import njit


@njit
def bar(a):
    a.append(10)


@njit
def foo():
    z = []
    bar(z)


foo()

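This will fail to compile and run, because the type of z cannot be inferred: it is only refined inside bar. If we now add inline='always' to the decorator for bar, the snippet will compile and run. This is because inlining the call to a.append(10) means that z is refined to hold integers, so type inference succeeds.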

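So, to recap, inlining at the Numba IR level is unlikely to bring a performance benefit, whereas inlining at the LLVM level is more likely to.

The inline keyword argument can be one of three values:

  • The string 'never', which is the default and results in the function never being inlined under any circumstances.

  • The string 'always', which results in the function being inlined at all call sites.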

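  • A Python function that takes three arguments. The first argument is always the ir.Expr node of the call that is requesting the inline; it is present so that the function can make decisions based on the calling context. The second and third arguments are:

    • In the case of untyped inlining, i.e. when using the numba.jit() family of decorators, both arguments are numba.core.ir.FunctionIR instances. The second argument corresponds to the IR of the caller, the third to the IR of the callee.

    • In the case of typed inlining, i.e. when using numba.extending.overload(), both arguments are instances of a namedtuple with the following fields (which correspond to their standard use in the compiler internals):

      • func_ir - the Numba IR of the function.

      • typemap - the type map of the function.

      • calltypes - the call types of any calls in the function.

      • signature - the signature of the function.

      The second argument holds the information for the caller, the third holds the information for the callee.

    In all cases the function should return True to inline and False to not inline. This essentially permits custom inlining rules (a typical use might be a cost model).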

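  • Recursive functions with inline='always' will result in a non-terminating compilation. To avoid this, supply a function that limits the recursion depth (see below).

Note

No guarantee is made about the order in which functions are assessed for inlining or about the order in which they are inlined.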
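To make the three-argument callable more concrete, the following is a minimal sketch of a size-based rule for the untyped case. The name small_callee_cost_model and the threshold MAX_STATEMENTS are hypothetical and chosen only for illustration. The sketch assumes nothing about the IR objects beyond a blocks mapping whose values expose a body list, and it is exercised with plain stand-in objects rather than real FunctionIR instances.

```python
from types import SimpleNamespace

MAX_STATEMENTS = 10  # hypothetical threshold, chosen only for illustration


def small_callee_cost_model(expr, caller_ir, callee_ir):
    """Return True (inline) only if the callee IR is small enough."""
    # Count the statements across all basic blocks of the callee.
    n_statements = sum(len(blk.body) for blk in callee_ir.blocks.values())
    return n_statements <= MAX_STATEMENTS


# Stand-ins that mimic only the `.blocks` / `.body` attributes used above;
# they are not real Numba IR objects.
small = SimpleNamespace(blocks={0: SimpleNamespace(body=[None] * 3)})
large = SimpleNamespace(blocks={0: SimpleNamespace(body=[None] * 8),
                                1: SimpleNamespace(body=[None] * 8)})

decisions = (small_callee_cost_model(None, None, small),
             small_callee_cost_model(None, None, large))
```

A callable of this shape could then be passed as inline=small_callee_cost_model to one of the numba.jit() family of decorators. Because the order of assessment is not guaranteed, a rule like this is best kept independent of the order in which call sites are visited.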

Example using numba.jit()

An example of using all three options to inline in the numba.njit() decorator:

from numba import njit
import numba
from numba.core import ir


@njit(inline='never')
def never_inline():
    return 100


@njit(inline='always')
def always_inline():
    return 200


def sentinel_cost_model(expr, caller_info, callee_info):
    # this cost model will return True (i.e. do inlining) if either:
    # a) the callee IR contains an `ir.Const(37)`
    # b) the caller IR contains an `ir.Const(13)` logically prior to the call
    #    site

    # check the callee
    for blk in callee_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 37:
                        return True

    # check the caller
    before_expr = True
    for blk in caller_info.blocks.values():
        for stmt in blk.body:
            if isinstance(stmt, ir.Assign):
                if isinstance(stmt.value, ir.Expr):
                    if stmt.value == expr:
                        before_expr = False
                if isinstance(stmt.value, ir.Const):
                    if stmt.value.value == 13:
                        return True & before_expr
    return False


@njit(inline=sentinel_cost_model)
def maybe_inline1():
    # Will not inline based on the callee IR with the declared cost model
    # The following is ir.Const(300).
    return 300


@njit(inline=sentinel_cost_model)
def maybe_inline2():
    # Will inline based on the callee IR with the declared cost model
    # The following is ir.Const(37).
    return 37


@njit
def foo():
    a = never_inline()  # will never inline
    b = always_inline()  # will always inline

    # will not inline as the function does not contain a magic constant known to
    # the cost model, and the IR up to the call site does not contain a magic
    # constant either
    d = maybe_inline1()

    # declare this magic constant to trigger inlining of maybe_inline1 in a
    # subsequent call
    magic_const = 13

    # will inline due to above constant declaration
    e = maybe_inline1()

    # will inline as the maybe_inline2 function contains a magic constant known
    # to the cost model
    c = maybe_inline2()

    return a + b + c + d + e + magic_const


foo()

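which, when executed, produces the following output (this prints the IR after the legalization pass and is enabled via the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization"):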

label 0:
    $0.1 = global(never_inline: CPUDispatcher(<function never_inline at 0x7f890ccf9048>)) ['$0.1']
    $0.2 = call $0.1(func=$0.1, args=[], kws=(), vararg=None) ['$0.1', '$0.2']
    del $0.1                                 []
    a = $0.2                                 ['$0.2', 'a']
    del $0.2                                 []
    $0.3 = global(always_inline: CPUDispatcher(<function always_inline at 0x7f890ccf9598>)) ['$0.3']
    del $0.3                                 []
    $const0.1.0 = const(int, 200)            ['$const0.1.0']
    $0.2.1 = $const0.1.0                     ['$0.2.1', '$const0.1.0']
    del $const0.1.0                          []
    $0.4 = $0.2.1                            ['$0.2.1', '$0.4']
    del $0.2.1                               []
    b = $0.4                                 ['$0.4', 'b']
    del $0.4                                 []
    $0.5 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.5']
    $0.6 = call $0.5(func=$0.5, args=[], kws=(), vararg=None) ['$0.5', '$0.6']
    del $0.5                                 []
    d = $0.6                                 ['$0.6', 'd']
    del $0.6                                 []
    $const0.7 = const(int, 13)               ['$const0.7']
    magic_const = $const0.7                  ['$const0.7', 'magic_const']
    del $const0.7                            []
    $0.8 = global(maybe_inline1: CPUDispatcher(<function maybe_inline1 at 0x7f890ccf9ae8>)) ['$0.8']
    del $0.8                                 []
    $const0.1.2 = const(int, 300)            ['$const0.1.2']
    $0.2.3 = $const0.1.2                     ['$0.2.3', '$const0.1.2']
    del $const0.1.2                          []
    $0.9 = $0.2.3                            ['$0.2.3', '$0.9']
    del $0.2.3                               []
    e = $0.9                                 ['$0.9', 'e']
    del $0.9                                 []
    $0.10 = global(maybe_inline2: CPUDispatcher(<function maybe_inline2 at 0x7f890ccf9b70>)) ['$0.10']
    del $0.10                                []
    $const0.1.4 = const(int, 37)             ['$const0.1.4']
    $0.2.5 = $const0.1.4                     ['$0.2.5', '$const0.1.4']
    del $const0.1.4                          []
    $0.11 = $0.2.5                           ['$0.11', '$0.2.5']
    del $0.2.5                               []
    c = $0.11                                ['$0.11', 'c']
    del $0.11                                []
    $0.14 = a + b                            ['$0.14', 'a', 'b']
    del b                                    []
    del a                                    []
    $0.16 = $0.14 + c                        ['$0.14', '$0.16', 'c']
    del c                                    []
    del $0.14                                []
    $0.18 = $0.16 + d                        ['$0.16', '$0.18', 'd']
    del d                                    []
    del $0.16                                []
    $0.20 = $0.18 + e                        ['$0.18', '$0.20', 'e']
    del e                                    []
    del $0.18                                []
    $0.22 = $0.20 + magic_const              ['$0.20', '$0.22', 'magic_const']
    del magic_const                          []
    del $0.20                                []
    $0.23 = cast(value=$0.22)                ['$0.22', '$0.23']
    del $0.22                                []
    return $0.23                             ['$0.23']

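Things to note in the above:

  1. The call to the function never_inline remains as a call.

  2. The always_inline function has been inlined; note its const(int, 200) in the caller body.

  3. There is a call to maybe_inline1 before the const(int, 13) declaration; the cost model prevented this one from being inlined.

  4. After the const(int, 13), the subsequent call to maybe_inline1 has been inlined, as shown by the const(int, 300) in the caller body.

  5. The function maybe_inline2 has been inlined, as demonstrated by the const(int, 37) in the caller body.

  6. Dead code elimination has not been performed, so there are superfluous statements present in the IR.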

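Example using numba.extending.overload()

An example of using inlining with the numba.extending.overload() decorator. It is worth noting that if a function is supplied as the argument to inline, considerably more information is available through the supplied function arguments for use in decision making. Also, different @overload implementations can have different inlining behaviours, and there are multiple ways to achieve this: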

import numba
from numba.extending import overload
from numba import njit, types


def bar(x):
    """A function stub to overload"""
    pass


@overload(bar, inline='always')
def ol_bar_tuple(x):
    # An overload that will always inline, there is a type guard so that this
    # only applies to UniTuples.
    if isinstance(x, types.UniTuple):
        def impl(x):
            return x[0]
        return impl


def cost_model(expr, caller, callee):
    # Only inline if the type of the argument is an Integer
    return isinstance(caller.typemap[expr.args[0].name], types.Integer)


@overload(bar, inline=cost_model)
def ol_bar_scalar(x):
    # An overload that will inline based on a cost model, it only applies to
    # scalar values in the numerical domain as per the type guard on Number
    if isinstance(x, types.Number):
        def impl(x):
            return x + 1
        return impl


@njit
def foo():

    # This will resolve via `ol_bar_tuple` as the argument is a types.UniTuple
    # instance. It will always be inlined as specified in the decorator for this
    # overload.
    a = bar((1, 2, 3))

    # This will resolve via `ol_bar_scalar` as the argument is a types.Number
    # instance, hence the cost_model will be used to determine whether to
    # inline.
    # The function will be inlined as the value 100 is an IntegerLiteral which
    # is an instance of a types.Integer as required by the cost_model function.
    b = bar(100)

    # This will also resolve via `ol_bar_scalar` as the argument is a
    # types.Number instance, again the cost_model will be used to determine
    # whether to inline.
    # The function will not be inlined as the complex value is not an instance
    # of a types.Integer as required by the cost_model function.
    c = bar(300j)

    return a + b + c


foo()

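which, when executed, produces the following output (this prints the IR after the legalization pass and is enabled via the environment variable NUMBA_DEBUG_PRINT_AFTER="ir_legalization"):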

label 0:
    $const0.2 = const(tuple, (1, 2, 3))      ['$const0.2']
    x.0 = $const0.2                          ['$const0.2', 'x.0']
    del $const0.2                            []
    $const0.2.2 = const(int, 0)              ['$const0.2.2']
    $0.3.3 = getitem(value=x.0, index=$const0.2.2) ['$0.3.3', '$const0.2.2', 'x.0']
    del x.0                                  []
    del $const0.2.2                          []
    $0.4.4 = $0.3.3                          ['$0.3.3', '$0.4.4']
    del $0.3.3                               []
    $0.3 = $0.4.4                            ['$0.3', '$0.4.4']
    del $0.4.4                               []
    a = $0.3                                 ['$0.3', 'a']
    del $0.3                                 []
    $const0.5 = const(int, 100)              ['$const0.5']
    x.5 = $const0.5                          ['$const0.5', 'x.5']
    del $const0.5                            []
    $const0.2.7 = const(int, 1)              ['$const0.2.7']
    $0.3.8 = x.5 + $const0.2.7               ['$0.3.8', '$const0.2.7', 'x.5']
    del x.5                                  []
    del $const0.2.7                          []
    $0.4.9 = $0.3.8                          ['$0.3.8', '$0.4.9']
    del $0.3.8                               []
    $0.6 = $0.4.9                            ['$0.4.9', '$0.6']
    del $0.4.9                               []
    b = $0.6                                 ['$0.6', 'b']
    del $0.6                                 []
    $0.7 = global(bar: <function bar at 0x7f6c3710d268>) ['$0.7']
    $const0.8 = const(complex, 300j)         ['$const0.8']
    $0.9 = call $0.7($const0.8, func=$0.7, args=[Var($const0.8, inline_overload_example.py (56))], kws=(), vararg=None) ['$0.7', '$0.9', '$const0.8']
    del $const0.8                            []
    del $0.7                                 []
    c = $0.9                                 ['$0.9', 'c']
    del $0.9                                 []
    $0.12 = a + b                            ['$0.12', 'a', 'b']
    del b                                    []
    del a                                    []
    $0.14 = $0.12 + c                        ['$0.12', '$0.14', 'c']
    del c                                    []
    del $0.12                                []
    $0.15 = cast(value=$0.14)                ['$0.14', '$0.15']
    del $0.14                                []
    return $0.15                             ['$0.15']

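Things to note in the above:

  1. The first section of the IR (the getitem on the tuple constant) is the always-inlined overload for the UniTuple argument type.

  2. The second section (x.5 + 1 applied to const(int, 100)) is the overload for the Number argument type, which was inlined because the cost model function accepted it: the argument is an instance of an Integer type.

  3. The third section (the remaining call to bar with const(complex, 300j)) is the overload for the Number argument type, which was not inlined because the cost model function rejected it: the argument is an instance of a Complex type.

  4. Dead code elimination has not been performed, so there are superfluous statements present in the IR.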

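Using a function to limit the inlining depth of a recursive function

When using recursive inlining, you can ensure that compilation terminates by using a cost model.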

from numba import njit

class CostModel(object):
    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        ret = self._count < self._max_inlines
        self._count += 1
        return ret

@njit(inline=CostModel(3))
def factorial(n):
    if n <= 0:
        return 1
    return n * factorial(n - 1)

factorial(5)
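The behaviour of the counter can be checked without involving the compiler. The sketch below repeats the CostModel class so that it is self-contained, and calls the instance directly with None placeholders, since the class ignores its three arguments. How many times the compiler actually invokes the callable is not shown here and is not assumed.

```python
class CostModel(object):
    def __init__(self, max_inlines):
        self._count = 0
        self._max_inlines = max_inlines

    def __call__(self, expr, caller, callee):
        # Permit inlining only for the first max_inlines requests.
        ret = self._count < self._max_inlines
        self._count += 1
        return ret


model = CostModel(3)

# Placeholders stand in for the expr, caller and callee arguments.
decisions = [model(None, None, None) for _ in range(5)]
```

The first three requests are accepted and every later one is refused, which is what bounds the inlining depth. Note that the counter lives on the instance, so it is shared by every call site that consults the same CostModel object.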