高级扩展API
此扩展API通过 numba.extending
模块暴露。
为了帮助调试Numba的扩展,建议设置以下环境变量:
NUMBA_CAPTURED_ERRORS="new_style"
这使得区分实现中的错误和可接受的错误变得容易,例如可以参与类型推断的错误。更多信息请参见 NUMBA_CAPTURED_ERRORS
。
实现功能
@overload
装饰器允许你实现任意函数,以便在 nopython 模式 函数中使用。被 @overload
装饰的函数在编译时会以函数运行时参数的 类型 调用。它应返回一个可调用对象,表示给定类型的函数的 实现。返回的实现会被 Numba 编译,就像它是一个用 @jit
装饰的普通函数一样。可以通过 jit_options
参数传递字典形式的 @jit
的额外选项。
例如,假设 Numba 还不支持元组上的 len()
函数。以下是如何使用 @overload
实现它:
from numba import types
from numba.extending import overload
@overload(len)
def tuple_len(seq):
if isinstance(seq, types.BaseTuple):
n = len(seq)
def len_impl(seq):
return n
return len_impl
你可能会好奇,如果 len()
被调用时传入的不是元组会发生什么?如果一个用 @overload
装饰的函数没有返回任何东西(即返回 None),则会尝试其他定义,直到其中一个成功。因此,多个库可以为不同类型重载 len()
而不会相互冲突。
实现方法
@overload_method
装饰器同样允许在 Numba 中实现一个众所周知的类型的方法。
- numba.core.extending.overload_method(typ, attr, **kwargs)[源代码]
一个装饰器,将装饰的函数标记为类型化,并在 nopython 模式下为给定的 Numba 类型实现方法 attr。
kwargs 被传递给底层的 @overload 调用。
以下是一个为数组类型实现 .take() 的示例:
@overload_method(types.Array, 'take') def array_take(arr, indices): if isinstance(indices, types.Array): def take_impl(arr, indices): n = indices.shape[0] res = np.empty(n, arr.dtype) for i in range(n): res[i] = arr[indices[i]] return res return take_impl
实现类方法
@overload_classmethod
装饰器同样允许在 Numba 中实现一个类型已知的类方法。
- numba.core.extending.overload_classmethod(typ, attr, **kwargs)[源代码]
一个装饰器,将装饰的函数标记为类型,并在 nopython 模式下为给定的 Numba 类型实现类方法 attr。
类似于
overload_method
。以下是一个在 Array 类型上实现 classmethod 以调用
np.arange()
的示例:@overload_classmethod(types.Array, "make") def ov_make(cls, nitems): def impl(cls, nitems): return np.arange(nitems) return impl
上述代码将允许以下内容在即时编译代码中工作:
@njit def foo(n): return types.Array.make(n)
实现属性
@overload_attribute
装饰器允许在类型上实现数据属性(或属性)。只能读取属性;可写属性仅通过 低级 API 支持。
以下示例实现了 Numpy 数组上的 nbytes
属性:
@overload_attribute(types.Array, 'nbytes')
def array_nbytes(arr):
def get(arr):
return arr.size * arr.itemsize
return get
导入 Cython 函数
函数 get_cython_function_address
获取 Cython 扩展模块中 C 函数的地址。该地址可以通过 ctypes.CFUNCTYPE()
回调来访问 C 函数,从而允许在 Numba jitted 函数中使用该 C 函数。例如,假设你有一个文件 foo.pyx
:
from libc.math cimport exp
cdef api double myexp(double x):
return exp(x)
你可以通过以下方式从 Numba 访问 myexp
:
import ctypes
from numba.extending import get_cython_function_address
addr = get_cython_function_address("foo", "myexp")
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)
myexp = functype(addr)
函数 myexp
现在可以在 jitted 函数内部使用,例如:
@njit
def double_myexp(x):
return 2*myexp(x)
需要注意的是,如果你的函数使用了 Cython 的融合类型,那么函数的名称将会被修改。要找出函数被修改后的名称,你可以检查扩展模块的 __pyx_capi__
属性。
实现内部函数
@intrinsic
装饰器用于将函数 func 标记为类型,并使用 llvmlite IRBuilder API 在 nopython
模式下实现该函数。这是为专家用户提供的一个逃生舱,用于构建将被内联到调用者中的自定义 LLVM IR,没有安全网!
第一个传递给 func 的参数是类型上下文。其余的参数对应于被装饰函数的参数类型。这些参数也被用作被装饰函数的正式参数。如果 func 的签名是 foo(typing_context, arg0, arg1)
,那么被装饰函数的签名将是 foo(arg0, arg1)
。
函数 func 的返回值应为一个 2-元组,包含预期的类型签名和一个代码生成函数,该函数将被传递给 lower_builtin()
。对于不支持的操作,返回 None
。
这是一个将任意整数转换为字节指针的示例:
from numba import types
from numba.extending import intrinsic
@intrinsic
def cast_int_to_byte_ptr(typingctx, src):
# check for accepted types
if isinstance(src, types.Integer):
# create the expected type signature
result_type = types.CPointer(types.uint8)
sig = result_type(types.uintp)
# defines the custom code generation
def codegen(context, builder, signature, args):
# llvm IRBuilder code here
[src] = args
rtype = signature.return_type
llrtype = context.get_value_type(rtype)
return builder.inttoptr(src, llrtype)
return sig, codegen
它可以如下使用:
from numba import njit
@njit('void(int64)')
def foo(x):
y = cast_int_to_byte_ptr(x)
foo.inspect_types()
而 .inspect_types()
的输出展示了这种转换(注意 uint8*
):
def foo(x):
# x = arg(0, name=x) :: int64
# $0.1 = global(cast_int_to_byte_ptr: <intrinsic cast_int_to_byte_ptr>) :: Function(<intrinsic cast_int_to_byte_ptr>)
# $0.3 = call $0.1(x, func=$0.1, args=[Var(x, check_intrin.py (24))], kws=(), vararg=None) :: (uint64,) -> uint8*
# del x
# del $0.1
# y = $0.3 :: uint8*
# del y
# del $0.3
# $const0.4 = const(NoneType, None) :: none
# $0.5 = cast(value=$const0.4) :: none
# del $const0.4
# return $0.5
y = cast_int_to_byte_ptr(x)
实现可变结构
警告
这是一个实验性功能,API 可能会在没有警告的情况下发生变化。
numba.experimental.structref
模块提供了用于定义可变引用传递结构的工具,即 StructRef
。以下示例展示了如何定义一个基本的可变结构:
定义一个 StructRef
1import numpy as np
2
3from numba import njit
4from numba.core import types
5from numba.experimental import structref
6
7from numba.tests.support import skip_unless_scipy
8
9
10# Define a StructRef.
11# `structref.register` associates the type with the default data model.
12# This will also install getters and setters to the fields of
13# the StructRef.
14@structref.register
15class MyStructType(types.StructRef):
16 def preprocess_fields(self, fields):
17 # This method is called by the type constructor for additional
18 # preprocessing on the fields.
19 # Here, we don't want the struct to take Literal types.
20 return tuple((name, types.unliteral(typ)) for name, typ in fields)
21
22
23# Define a Python type that can be use as a proxy to the StructRef
24# allocated inside Numba. Users can construct the StructRef via
25# the constructor for this type in python code and jit-code.
26class MyStruct(structref.StructRefProxy):
27 def __new__(cls, name, vector):
28 # Overriding the __new__ method is optional, doing so
29 # allows Python code to use keyword arguments,
30 # or add other customized behavior.
31 # The default __new__ takes `*args`.
32 # IMPORTANT: Users should not override __init__.
33 return structref.StructRefProxy.__new__(cls, name, vector)
34
35 # By default, the proxy type does not reflect the attributes or
36 # methods to the Python side. It is up to users to define
37 # these. (This may be automated in the future.)
38
39 @property
40 def name(self):
41 # To access a field, we can define a function that simply
42 # return the field in jit-code.
43 # The definition of MyStruct_get_name is shown later.
44 return MyStruct_get_name(self)
45
46 @property
47 def vector(self):
48 # The definition of MyStruct_get_vector is shown later.
49 return MyStruct_get_vector(self)
50
51
52@njit
53def MyStruct_get_name(self):
54 # In jit-code, the StructRef's attribute is exposed via
55 # structref.register
56 return self.name
57
58
59@njit
60def MyStruct_get_vector(self):
61 return self.vector
62
63
64# This associates the proxy with MyStructType for the given set of
65# fields. Notice how we are not constraining the type of each field.
66# Field types remain generic.
67structref.define_proxy(MyStruct, MyStructType, ["name", "vector"])
以下演示了使用上述可变结构体定义:
1# Let's test our new StructRef.
2
3# Define one in Python
4alice = MyStruct("Alice", vector=np.random.random(3))
5
6# Define one in jit-code
7@njit
8def make_bob():
9 bob = MyStruct("unnamed", vector=np.zeros(3))
10 # Mutate the attributes
11 bob.name = "Bob"
12 bob.vector = np.random.random(3)
13 return bob
14
15bob = make_bob()
16
17# Out: Alice: [0.5488135 0.71518937 0.60276338]
18print(f"{alice.name}: {alice.vector}")
19# Out: Bob: [0.88325739 0.73527629 0.87746707]
20print(f"{bob.name}: {bob.vector}")
21
22# Define a jit function to operate on the structs.
23@njit
24def distance(a, b):
25 return np.linalg.norm(a.vector - b.vector)
26
27# Out: 0.4332647200356598
28print(distance(alice, bob))
在 StructRef 上定义一个方法
方法和属性可以使用 @overload_*
附加,如前几节所示。
以下演示了如何使用 @overload_method
为 MyStructType
的实例插入一个方法:
1from numba.core.extending import overload_method
2from numba.core.errors import TypingError
3
4# Use @overload_method to add a method for
5# MyStructType.distance(other)
6# where *other* is an instance of MyStructType.
7@overload_method(MyStructType, "distance")
8def ol_distance(self, other):
9 # Guard that *other* is an instance of MyStructType
10 if not isinstance(other, MyStructType):
11 raise TypingError(
12 f"*other* must be a {MyStructType}; got {other}"
13 )
14
15 def impl(self, other):
16 return np.linalg.norm(self.vector - other.vector)
17
18 return impl
19
20# Test
21@njit
22def test():
23 alice = MyStruct("Alice", vector=np.random.random(3))
24 bob = MyStruct("Bob", vector=np.random.random(3))
25 # Use the method
26 return alice.distance(bob)
numba.experimental.structref
API 参考
用于定义可变结构的工具。
可变结构体通过引用传递;因此,structref(对结构体的引用)。
- class numba.experimental.structref.StructRefProxy(*args)[源代码]
一个指向 Numba 分配的 structref 数据结构的 PyObject 代理。
注释
子类不应定义
__init__
。子类可以重写
__new__
。
- numba.experimental.structref.define_attributes(struct_typeclass)[源代码]
在 struct_typeclass 上定义属性。
在 jit-code 中定义了设置器和获取器。
这直接在 register() 中调用。
- numba.experimental.structref.define_boxing(struct_type, obj_class)[源代码]
为 struct_type 定义装箱和拆箱逻辑到 obj_class。
定义了装箱和拆箱。
boxing 将 struct_type 的实例转换为 obj_class 的 PyObject
解包将 obj_class 的实例在 jit-code 中转换为 struct_type 的实例。
当用户不希望定义任何构造函数时,直接使用此方法代替 define_proxy()。
- numba.experimental.structref.define_constructor(py_class, struct_typeclass, fields)[源代码]
使用 Python 类型 py_class 和所需的 fields 定义 struct_typeclass 的 jit-code 构造函数。
如果用户不希望定义装箱逻辑,请使用此方法代替 define_proxy()。
- numba.experimental.structref.define_proxy(py_class, struct_typeclass, fields)[源代码]
为 structref 定义一个 PyObject 代理。
这使得 py_class 成为一个有效的构造函数,用于创建包含 fields 定义的成员的 struct_typeclass 实例。
- 参数:
- py_class类型
用于构造 struct_typeclass 实例的 Python 类。
- struct_typeclassnumba.core.types.Type
要绑定的 structref 类型类。
- 字段Sequence[str]
一系列字段名称。
- 返回:
- 无
- numba.experimental.structref.register(struct_type)[源代码]
在jit代码中注册一个 numba.core.types.StructRef 以供使用。
这定义了降低 struct_type 实例的数据模型。这定义了 struct_type 实例的属性访问器和修改器。
- 参数:
- struct_type类型
numba.core.types.StructRef 的子类。
- 返回:
- struct_type类型
返回输入参数,因此这可以作为装饰器使用。
示例
class MyStruct(numba.core.types.StructRef): ... # the simplest subclass can be empty numba.experimental.structref.register(MyStruct)
确定一个函数是否已经被 jit
家族的装饰器包装
以下函数为此目的提供。
- extending.is_jitted()
如果一个函数被 Numba 的 @jit 装饰器之一包装,则返回 True,例如:numba.jit, numba.njit
此函数的目的在于提供一种方法来检查函数是否已经进行了JIT装饰。