高级扩展API

此扩展API通过 numba.extending 模块暴露。

为了帮助调试Numba的扩展,建议设置以下环境变量:

NUMBA_CAPTURED_ERRORS="new_style"

这使得区分实现中的错误和可接受的错误变得容易,例如可以参与类型推断的错误。更多信息请参见 NUMBA_CAPTURED_ERRORS

实现功能

@overload 装饰器允许你实现任意函数,以便在 nopython 模式 函数中使用。被 @overload 装饰的函数在编译时会以函数运行时参数的 类型 调用。它应返回一个可调用对象,表示给定类型的函数的 实现。返回的实现会被 Numba 编译,就像它是一个用 @jit 装饰的普通函数一样。可以通过 jit_options 参数传递字典形式的 @jit 的额外选项。

例如,假设 Numba 还不支持元组上的 len() 函数。以下是如何使用 @overload 实现它:

from numba import types
from numba.extending import overload

@overload(len)
def tuple_len(seq):
   if isinstance(seq, types.BaseTuple):
       n = len(seq)
       def len_impl(seq):
           return n
       return len_impl

你可能会好奇,如果 len() 被调用时传入的不是元组会发生什么?如果一个用 @overload 装饰的函数没有返回任何东西(即返回 None),则会尝试其他定义,直到其中一个成功。因此,多个库可以为不同类型重载 len() 而不会相互冲突。

实现方法

@overload_method 装饰器同样允许在 Numba 中实现一个众所周知的类型的方法。

numba.core.extending.overload_method(typ, attr, **kwargs)[源代码]

一个装饰器,将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现方法 attr

kwargs 被传递给底层的 @overload 调用。

以下是一个为数组类型实现 .take() 的示例:

@overload_method(types.Array, 'take')
def array_take(arr, indices):
    if isinstance(indices, types.Array):
        def take_impl(arr, indices):
            n = indices.shape[0]
            res = np.empty(n, arr.dtype)
            for i in range(n):
                res[i] = arr[indices[i]]
            return res
        return take_impl

实现类方法

@overload_classmethod 装饰器同样允许在 Numba 中实现一个类型已知的类方法。

numba.core.extending.overload_classmethod(typ, attr, **kwargs)[源代码]

一个装饰器,将装饰的函数标记为类型,并在 nopython 模式下为给定的 Numba 类型实现类方法 attr

类似于 overload_method

以下是一个在 Array 类型上实现 classmethod 以调用 np.arange() 的示例:

@overload_classmethod(types.Array, "make")
def ov_make(cls, nitems):
    def impl(cls, nitems):
        return np.arange(nitems)
    return impl

上述代码将允许以下内容在即时编译代码中工作:

@njit
def foo(n):
    return types.Array.make(n)

实现属性

@overload_attribute 装饰器允许在类型上实现数据属性(或属性)。只能读取属性;可写属性仅通过 低级 API 支持。

以下示例实现了 Numpy 数组上的 nbytes 属性:

@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
   def get(arr):
       return arr.size * arr.itemsize
   return get

导入 Cython 函数

函数 get_cython_function_address 获取 Cython 扩展模块中 C 函数的地址。该地址可以通过 ctypes.CFUNCTYPE() 回调来访问 C 函数,从而允许在 Numba jitted 函数中使用该 C 函数。例如,假设你有一个文件 foo.pyx:

from libc.math cimport exp

cdef api double myexp(double x):
    return exp(x)

你可以通过以下方式从 Numba 访问 myexp:

import ctypes
from numba.extending import get_cython_function_address

addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)

函数 myexp 现在可以在 jitted 函数内部使用,例如:

@njit
def double_myexp(x):
    return 2*myexp(x)

需要注意的是,如果你的函数使用了 Cython 的融合类型,那么函数的名称将会被修改。要找出函数被修改后的名称,你可以检查扩展模块的 __pyx_capi__ 属性。

实现内部函数

@intrinsic 装饰器用于将函数 func 标记为类型,并使用 llvmlite IRBuilder APInopython 模式下实现该函数。这是为专家用户提供的一个逃生舱,用于构建将被内联到调用者中的自定义 LLVM IR,没有安全网!

第一个传递给 func 的参数是类型上下文。其余的参数对应于被装饰函数的参数类型。这些参数也被用作被装饰函数的正式参数。如果 func 的签名是 foo(typing_context, arg0, arg1),那么被装饰函数的签名将是 foo(arg0, arg1)

函数 func 的返回值应为一个 2-元组,包含预期的类型签名和一个代码生成函数,该函数将被传递给 lower_builtin()。对于不支持的操作,返回 None

这是一个将任意整数转换为字节指针的示例:

from numba import types
from numba.extending import intrinsic

@intrinsic
def cast_int_to_byte_ptr(typingctx, src):
    # check for accepted types
    if isinstance(src, types.Integer):
        # create the expected type signature
        result_type = types.CPointer(types.uint8)
        sig = result_type(types.uintp)
        # defines the custom code generation
        def codegen(context, builder, signature, args):
            # llvm IRBuilder code here
            [src] = args
            rtype = signature.return_type
            llrtype = context.get_value_type(rtype)
            return builder.inttoptr(src, llrtype)
        return sig, codegen

它可以如下使用:

from numba import njit

@njit('void(int64)')
def foo(x):
    y = cast_int_to_byte_ptr(x)

foo.inspect_types()

.inspect_types() 的输出展示了这种转换(注意 uint8*):

def foo(x):

    #   x = arg(0, name=x)  :: int64
    #   $0.1 = global(cast_int_to_byte_ptr: <intrinsic cast_int_to_byte_ptr>)  :: Function(<intrinsic cast_int_to_byte_ptr>)
    #   $0.3 = call $0.1(x, func=$0.1, args=[Var(x, check_intrin.py (24))], kws=(), vararg=None)  :: (uint64,) -> uint8*
    #   del x
    #   del $0.1
    #   y = $0.3  :: uint8*
    #   del y
    #   del $0.3
    #   $const0.4 = const(NoneType, None)  :: none
    #   $0.5 = cast(value=$const0.4)  :: none
    #   del $const0.4
    #   return $0.5

    y = cast_int_to_byte_ptr(x)

实现可变结构

警告

这是一个实验性功能,API 可能会在没有警告的情况下发生变化。

numba.experimental.structref 模块提供了用于定义可变引用传递结构的工具,即 StructRef。以下示例展示了如何定义一个基本的可变结构:

定义一个 StructRef

来自 numba/tests/doc_examples/test_structref_usage.py
 1import numpy as np
 2
 3from numba import njit
 4from numba.core import types
 5from numba.experimental import structref
 6
 7from numba.tests.support import skip_unless_scipy
 8
 9
10# Define a StructRef.
11# `structref.register` associates the type with the default data model.
12# This will also install getters and setters to the fields of
13# the StructRef.
14@structref.register
15class MyStructType(types.StructRef):
16    def preprocess_fields(self, fields):
17        # This method is called by the type constructor for additional
18        # preprocessing on the fields.
19        # Here, we don't want the struct to take Literal types.
20        return tuple((name, types.unliteral(typ)) for name, typ in fields)
21
22
23# Define a Python type that can be use as a proxy to the StructRef
24# allocated inside Numba. Users can construct the StructRef via
25# the constructor for this type in python code and jit-code.
26class MyStruct(structref.StructRefProxy):
27    def __new__(cls, name, vector):
28        # Overriding the __new__ method is optional, doing so
29        # allows Python code to use keyword arguments,
30        # or add other customized behavior.
31        # The default __new__ takes `*args`.
32        # IMPORTANT: Users should not override __init__.
33        return structref.StructRefProxy.__new__(cls, name, vector)
34
35    # By default, the proxy type does not reflect the attributes or
36    # methods to the Python side. It is up to users to define
37    # these. (This may be automated in the future.)
38
39    @property
40    def name(self):
41        # To access a field, we can define a function that simply
42        # return the field in jit-code.
43        # The definition of MyStruct_get_name is shown later.
44        return MyStruct_get_name(self)
45
46    @property
47    def vector(self):
48        # The definition of MyStruct_get_vector is shown later.
49        return MyStruct_get_vector(self)
50
51
52@njit
53def MyStruct_get_name(self):
54    # In jit-code, the StructRef's attribute is exposed via
55    # structref.register
56    return self.name
57
58
59@njit
60def MyStruct_get_vector(self):
61    return self.vector
62
63
64# This associates the proxy with MyStructType for the given set of
65# fields. Notice how we are not constraining the type of each field.
66# Field types remain generic.
67structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])

以下演示了使用上述可变结构体定义:

来自 numba/tests/doc_examples/test_structref_usage.py 中的 test_type_definition
 1# Let's test our new StructRef.
 2
 3# Define one in Python
 4alice = MyStruct("Alice", vector=np.random.random(3))
 5
 6# Define one in jit-code
 7@njit
 8def make_bob():
 9    bob = MyStruct("unnamed", vector=np.zeros(3))
10    # Mutate the attributes
11    bob.name = "Bob"
12    bob.vector = np.random.random(3)
13    return bob
14
15bob = make_bob()
16
17# Out: Alice: [0.5488135  0.71518937 0.60276338]
18print(f"{alice.name}: {alice.vector}")
19# Out: Bob: [0.88325739 0.73527629 0.87746707]
20print(f"{bob.name}: {bob.vector}")
21
22# Define a jit function to operate on the structs.
23@njit
24def distance(a, b):
25    return np.linalg.norm(a.vector - b.vector)
26
27# Out: 0.4332647200356598
28print(distance(alice, bob))

在 StructRef 上定义一个方法

方法和属性可以使用 @overload_* 附加,如前几节所示。

以下演示了如何使用 @overload_methodMyStructType 的实例插入一个方法:

来自 numba/tests/doc_examples/test_structref_usage.py 中的 test_overload_method
 1from numba.core.extending import overload_method
 2from numba.core.errors import TypingError
 3
 4# Use @overload_method to add a method for
 5# MyStructType.distance(other)
 6# where *other* is an instance of MyStructType.
 7@overload_method(MyStructType, "distance")
 8def ol_distance(self, other):
 9    # Guard that *other* is an instance of MyStructType
10    if not isinstance(other, MyStructType):
11        raise TypingError(
12            f"*other* must be a {MyStructType}; got {other}"
13        )
14
15    def impl(self, other):
16        return np.linalg.norm(self.vector - other.vector)
17
18    return impl
19
20# Test
21@njit
22def test():
23    alice = MyStruct("Alice", vector=np.random.random(3))
24    bob = MyStruct("Bob", vector=np.random.random(3))
25    # Use the method
26    return alice.distance(bob)

numba.experimental.structref API 参考

用于定义可变结构的工具。

可变结构体通过引用传递;因此,structref(对结构体的引用)。

class numba.experimental.structref.StructRefProxy(*args)[源代码]

一个指向 Numba 分配的 structref 数据结构的 PyObject 代理。

注释

  • 子类不应定义 __init__

  • 子类可以重写 __new__

numba.experimental.structref.define_attributes(struct_typeclass)[源代码]

struct_typeclass 上定义属性。

在 jit-code 中定义了设置器和获取器。

这直接在 register() 中调用。

numba.experimental.structref.define_boxing(struct_type, obj_class)[源代码]

struct_type 定义装箱和拆箱逻辑到 obj_class

定义了装箱和拆箱。

  • boxing 将 struct_type 的实例转换为 obj_class 的 PyObject

  • 解包将 obj_class 的实例在 jit-code 中转换为 struct_type 的实例。

当用户不希望定义任何构造函数时,直接使用此方法代替 define_proxy()

numba.experimental.structref.define_constructor(py_class, struct_typeclass, fields)[源代码]

使用 Python 类型 py_class 和所需的 fields 定义 struct_typeclass 的 jit-code 构造函数。

如果用户不希望定义装箱逻辑,请使用此方法代替 define_proxy()

numba.experimental.structref.define_proxy(py_class, struct_typeclass, fields)[源代码]

为 structref 定义一个 PyObject 代理。

这使得 py_class 成为一个有效的构造函数,用于创建包含 fields 定义的成员的 struct_typeclass 实例。

参数:
py_class类型

用于构造 struct_typeclass 实例的 Python 类。

struct_typeclassnumba.core.types.Type

要绑定的 structref 类型类。

字段Sequence[str]

一系列字段名称。

返回:
numba.experimental.structref.register(struct_type)[源代码]

在jit代码中注册一个 numba.core.types.StructRef 以供使用。

这定义了降低 struct_type 实例的数据模型。这定义了 struct_type 实例的属性访问器和修改器。

参数:
struct_type类型

numba.core.types.StructRef 的子类。

返回:
struct_type类型

返回输入参数,因此这可以作为装饰器使用。

示例

class MyStruct(numba.core.types.StructRef):
    ...  # the simplest subclass can be empty

numba.experimental.structref.register(MyStruct)

确定一个函数是否已经被 jit 家族的装饰器包装

以下函数为此目的提供。

extending.is_jitted()

如果一个函数被 Numba 的 @jit 装饰器之一包装,则返回 True,例如:numba.jit, numba.njit

此函数的目的在于提供一种方法来检查函数是否已经进行了JIT装饰。