使用 @overload 的指南

高层扩展API 中所述,您可以使用 @overload 装饰器来创建一个可以在 nopython模式 函数中使用的Numba实现。一个常见的用例是重新实现NumPy函数,以便它们可以在 @jit 装饰的代码中调用。本节讨论何时以及如何使用 @overload 装饰器,以及向Numba代码库贡献这样一个函数可能涉及的内容。这应该有助于您在使用 @overload 装饰器或尝试向Numba本身贡献新函数时开始。

@overload 装饰器及其变体在你有一个不受控制的第三方库,并且你希望为该库中的特定函数提供Numba兼容的实现时非常有用。

具体示例

让我们假设你正在研究一个最小化算法,该算法使用 scipy.linalg.norm 来寻找不同的向量范数和矩阵的 frobenius 范数。你知道只会涉及整数和实数。(虽然这听起来可能像一个人工的例子,特别是由于存在 numpy.linalg.norm 的 Numba 实现,但它主要是教学性的,用于说明何时以及如何使用 @overload)。

骨架可能看起来像这样:

def algorithm():
    # setup
    v = ...
    while True:
        # take a step
        d = scipy.linalg.norm(v)
        if d < tolerance:
            break

现在,让我们进一步假设,你已经听说过 Numba,并且你现在希望使用它来加速你的函数。然而,在添加了 jit(nopython=True) 装饰器后,Numba 抱怨 scipy.linalg.norm 不受支持。通过查看文档,你意识到使用 NumPy 实现范数可能相当简单。一个好的起点是以下模板。

# Declare that function `myfunc` is going to be overloaded (have a
# substitutable Numba implementation)
@overload(myfunc)
# Define the overload function with formal arguments
# these arguments must be matched in the inner function implementation
def jit_myfunc(arg0, arg1, arg2, ...):
    # This scope is for typing, access is available to the *type* of all
    # arguments. This information can be used to change the behaviour of the
    # implementing function and check that the types are actually supported
    # by the implementation.

    print(arg0) # this will show the Numba type of arg0

    # This is the definition of the function that implements the `myfunc` work.
    # It does whatever algorithm is needed to implement myfunc.
    def myfunc_impl(arg0, arg1, arg2, ...): # match arguments to jit_myfunc
        # < Implementation goes here >
        return # whatever needs to be returned by the algorithm

    # return the implementation
    return myfunc_impl

经过一番思考和调整,你最终得到了以下代码:

import numpy as np
from numba import njit, types
from numba.extending import overload, register_jitable
from numba.core.errors import TypingError

import scipy.linalg


@register_jitable
def _oneD_norm_2(a):
    # re-usable implementation of the 2-norm
    val = np.abs(a)
    return np.sqrt(np.sum(val * val))


@overload(scipy.linalg.norm)
def jit_norm(a, ord=None):
    if isinstance(ord, types.Optional):
        ord = ord.type
    # Reject non integer, floating-point or None types for ord
    if not isinstance(ord, (types.Integer, types.Float, types.NoneType)):
        raise TypingError("'ord' must be either integer or floating-point")
    # Reject non-ndarray types
    if not isinstance(a, types.Array):
        raise TypingError("Only accepts NumPy ndarray")
    # Reject ndarrays with non integer or floating-point dtype
    if not isinstance(a.dtype, (types.Integer, types.Float)):
        raise TypingError("Only integer and floating point types accepted")
    # Reject ndarrays with unsupported dimensionality
    if not (0 <= a.ndim <= 2):
        raise TypingError('3D and beyond are not allowed')
    # Implementation for scalars/0d-arrays
    elif a.ndim == 0:
        return a.item()
    # Implementation for vectors
    elif a.ndim == 1:
        def _oneD_norm_x(a, ord=None):
            if ord == 2 or ord is None:
                return _oneD_norm_2(a)
            elif ord == np.inf:
                return np.max(np.abs(a))
            elif ord == -np.inf:
                return np.min(np.abs(a))
            elif ord == 0:
                return np.sum(a != 0)
            elif ord == 1:
                return np.sum(np.abs(a))
            else:
                return np.sum(np.abs(a)**ord)**(1. / ord)
        return _oneD_norm_x
    # Implementation for matrices
    elif a.ndim == 2:
        def _two_D_norm_2(a, ord=None):
            return _oneD_norm_2(a.ravel())
        return _two_D_norm_2


if __name__ == "__main__":
    @njit
    def use(a, ord=None):
        # simple test function to check that the overload works
        return scipy.linalg.norm(a, ord)

    # spot check for vectors
    a = np.arange(10)
    print(use(a))
    print(scipy.linalg.norm(a))

    # spot check for matrices
    b = np.arange(9).reshape((3, 3))
    print(use(b))
    print(scipy.linalg.norm(b))

如你所见,该实现仅支持你目前所需的功能:

  • 仅支持整数和浮点类型

  • 所有向量范数

  • 仅适用于矩阵的Frobenius范数

  • 使用 @register_jitable 在向量和矩阵实现之间共享代码。

  • 规范使用 NumPy 语法实现。(这是可能的,因为 Numba 非常了解 NumPy,并且支持许多函数。)

那么这里实际上发生了什么?overload 装饰器为 scipy.linalg.norm 注册了一个合适的实现,以防在代码中遇到对该函数的调用,例如当你用 @jit(nopython=True) 装饰你的 algorithm 函数时。在这种情况下,函数 jit_norm 将被调用,并传入当前遇到的类型,然后根据情况返回 _oneD_norm_x``(在向量情况下)或 ``_two_D_norm_2

你可以在这里下载示例代码:mynorm.py

为NumPy函数实现``@overload``

Numba 通过提供与 @jit 兼容的 NumPy 函数重新实现来支持 NumPy。在这种情况下,@overload 是一个非常方便的选项来编写这些实现,但还有一些额外的注意事项需要注意。

  • Numba 实现应尽可能与 NumPy 实现保持一致,包括接受的类型、参数、引发的异常以及算法复杂度(Big-O / Landau 阶数)。

  • 在实现支持的参数类型时,请记住,由于鸭子类型,NumPy确实倾向于接受多种参数类型,而不仅仅是NumPy数组,例如标量、列表、元组、集合、迭代器、生成器等。在类型推断期间以及随后作为测试的一部分时,您需要考虑到这一点。

  • 一个NumPy函数可能返回一个标量、数组或与其输入之一匹配的数据结构,你需要意识到类型统一问题并分派到适当的实现。例如,np.corrcoef 可能根据其输入返回一个数组或一个标量。

  • 如果你正在实现一个新功能,你应该始终更新 文档 。源文件可以在 docs/source/reference/numpysupported.rst 中找到。务必提及你的实现存在的任何限制,例如不支持 axis 关键字。

  • 在为功能本身编写测试时,包含对非有限值、不同形状和布局的数组、复杂输入、标量输入、未记录支持类型的输入(例如,NumPy文档中说明需要浮点数或整数输入的函数,如果给定布尔值或复数输入,也可能’有效’)的处理是有用的。

  • 在为异常编写测试时,例如在向 numba/tests/test_np_functions.py 添加测试时,您可能会遇到以下错误消息:

    ======================================================================
    FAIL: test_foo (numba.tests.test_np_functions.TestNPFunctions)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
    File "<path>/numba/numba/tests/support.py", line 645, in tearDown
        self.memory_leak_teardown()
    File "<path>/numba/numba/tests/support.py", line 619, in memory_leak_teardown
        self.assert_no_memory_leak()
    File "<path>/numba/numba/tests/support.py", line 628, in assert_no_memory_leak
        self.assertEqual(total_alloc, total_free)
    AssertionError: 36 != 35
    

    这是因为从 jitted 代码中引发异常会导致引用泄漏。理想情况下,您会将所有异常测试放在一个单独的测试方法中,然后在每个测试中添加对 self.disable_leak_check() 的调用来禁用泄漏检查(继承自 numba.tests.support.TestCase 以使其可用)。

  • 对于NumPy中许多可用的函数,在NumPy的``ndarray``类型上定义了相应的方法。例如,函数``repeat``既可以作为NumPy模块级别的函数使用,也可以作为``ndarray``类的成员函数使用。

    import numpy as np
    a = np.arange(10)
    # function
    np.repeat(a, 10)
    # method
    a.repeat(10)
    

    一旦你完成了函数实现,你可以轻松使用 @overload_method 并重用它。只需确保检查 NumPy 在其函数/方法实现中没有偏离。

    作为一个例子,repeat 函数/方法:

    @extending.overload_method(types.Array, 'repeat')
    def array_repeat(a, repeats):
        def array_repeat_impl(a, repeat):
            # np.repeat has already been overloaded
            return np.repeat(a, repeat)
    
        return array_repeat_impl
    
  • 如果你需要创建辅助函数,例如为了重用一个小型实用函数或为了可读性而将实现拆分为多个函数,你可以使用 @register_jitable 装饰器。这将使这些函数在你的 @jit@overload 装饰函数中可用。

  • Numba 的持续集成 (CI) 设置测试了多种 NumPy 版本,有时你会被提醒到某个之前的 NumPy 版本的行为发生了变化。如果你能在 NumPy 的变更日志 / 仓库中找到支持的证据,那么你需要决定是创建分支并尝试在不同版本间复制逻辑,还是使用版本门控(并在文档中附带相关说明)来告知用户 Numba 从某个特定版本开始复制 NumPy 的行为。

  • 你可以参考Numba的源代码以获取灵感,许多重载的NumPy函数和方法都在 numba/targets/arrayobj.py 中。下面,你将找到一个实现列表,这些实现在接受的类型和测试覆盖率方面都做得很好。

    • np.repeat