示例:区间类型

在这个示例中,我们将扩展Numba前端,以添加对一个用户定义的类的支持,该类在内部不受支持。这将允许:

  • 将类的实例传递给Numba函数

  • 在 Numba 函数中访问类的属性

  • 从Numba函数中构造并返回类的新实例

(以上所有内容在 nopython 模式 下)

我们将根据特定任务的可用性,混合使用 高级扩展API低级扩展API 中的API。

我们示例的起点是以下纯Python类:

class Interval(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, lo, hi):
        self.lo = lo
        self.hi = hi

    def __repr__(self):
        return 'Interval(%f, %f)' % (self.lo, self.hi)

    @property
    def width(self):
        return self.hi - self.lo

扩展类型层

创建一个新的 Numba 类型

由于 Interval 类对 Numba 来说是未知的,我们必须创建一个新的 Numba 类型来表示它的实例。Numba 不直接处理 Python 类型:它有自己的类型系统,允许不同级别的粒度以及常规 Python 类型中不可用的各种元信息。

我们首先创建一个类型类 IntervalType ,并且由于我们不需要该类型是参数化的,我们实例化一个单一的类型实例 interval_type

from numba import types

class IntervalType(types.Type):
    def __init__(self):
        super(IntervalType, self).__init__(name='Interval')

interval_type = IntervalType()

Python 值的类型推断

本身,创建一个 Numba 类型并不会做任何事情。我们必须教 Numba 如何将某些 Python 值推断为该类型的实例。在这个例子中,这是微不足道的:任何 Interval 类的实例都应该被视为属于 interval_type 类型:

from numba.extending import typeof_impl

@typeof_impl.register(Interval)
def typeof_index(val, c):
    return interval_type

函数参数和全局值将被识别为属于 interval_type,只要它们是 Interval 的实例。

Python 注解的类型推断

虽然 typeof 用于推断 Python 对象的 Numba 类型,但 as_numba_type 用于推断 Python 类型的 Numba 类型。对于简单的情况,我们可以简单地注册 Python 类型 Interval 对应于 Numba 类型 interval_type

from numba.extending import as_numba_type

as_numba_type.register(Interval, interval_type)

注意,as_numba_type 仅用于在编译时从类型注解中推断类型。上面的 typeof 注册表用于在运行时推断对象的类型。

操作的类型推断

我们希望能够在 Numba 函数中构造区间对象,因此我们必须教 Numba 识别两个参数的 Interval(lo, hi) 构造函数。参数应为浮点数:

from numba.extending import type_callable

@type_callable(Interval)
def type_interval(context):
    def typer(lo, hi):
        if isinstance(lo, types.Float) and isinstance(hi, types.Float):
            return interval_type
    return typer

The type_callable() 装饰器指定在为给定的可调用对象(此处为 Interval 类本身)运行类型推断时应调用装饰函数。装饰函数必须简单地返回一个类型函数,该函数将以参数类型调用。这种看似复杂的设置的原因是为了使类型函数具有与类型化可调用对象*完全*相同的签名。这允许正确处理关键字参数。

装饰函数接收的 context 参数在更复杂的情况下非常有用,其中计算可调用对象的返回类型需要解析其他类型。

扩展降低层

我们已经完成了向 Numba 介绍我们的类型推断添加内容。现在我们必须教会 Numba 如何为新操作实际生成代码和数据。

定义原生区间的数据模型

一般来说,nopython 模式 不适用于由 CPython 解释器生成的 Python 对象。解释器使用的表示方法对于快速的原生代码来说效率太低。因此,nopython 模式 中支持的每种类型都必须定义一个量身定制的原生表示,也称为 数据模型

数据模型的一个常见案例是类似于 C struct 的不可变结构化数据模型。我们的区间数据类型恰好属于这一类别,以下是该数据类型可能的数据模型:

from numba.extending import models, register_model

@register_model(IntervalType)
class IntervalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [('lo', types.float64),
                   ('hi', types.float64),]
        models.StructModel.__init__(self, dmm, fe_type, members)

这指示 Numba 类型 IntervalType 的值(或其任何实例)表示为一个包含两个字段 lohi 的结构,每个字段都是一个双精度浮点数(types.float64)。

备注

可变类型需要更复杂的数据模型,以便在修改后能够持久保存其值。它们通常不能像不可变类型那样存储并通过堆栈或寄存器传递。

暴露数据模型属性

我们希望数据模型属性 lohi 在 Numba 函数中以相同名称暴露。Numba 提供了一个便利函数来实现这一点:

from numba.extending import make_attribute_wrapper

make_attribute_wrapper(IntervalType, 'lo', 'lo')
make_attribute_wrapper(IntervalType, 'hi', 'hi')

这将暴露属性为只读模式。如上所述,可写属性不适合此模型。

公开属性

由于 width 属性是计算得出的,而不是存储在结构中,我们不能像对 lohi 那样简单地暴露它。我们必须显式地重新实现它:

from numba.extending import overload_attribute

@overload_attribute(IntervalType, "width")
def get_width(interval):
    def getter(interval):
        return interval.hi - interval.lo
    return getter

你可能会问,为什么我们不需要为此属性暴露一个类型推断钩子?答案是 @overload_attribute 是高级API的一部分:它在一个API中结合了类型推断和代码生成。

实现构造函数

现在我们想要实现带有两个参数的 Interval 构造函数:

from numba.extending import lower_builtin
from numba.core import cgutils

@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    typ = sig.return_type
    lo, hi = args
    interval = cgutils.create_struct_proxy(typ)(context, builder)
    interval.lo = lo
    interval.hi = hi
    return interval._getvalue()

这里还有一些其他的内容。@lower_builtin 装饰了给定可调用对象或操作(这里是 Interval 构造函数)的实现,用于某些特定的参数类型。这允许定义给定操作的类型特定实现,这对于像 len() 这样重载严重的函数非常重要。

types.Float 是所有浮点类型的类(types.float64types.Float 的一个实例)。通常,根据类的类型匹配参数类型比根据特定实例更具有未来性(然而,在 返回 类型时——主要是在类型推断阶段——,你通常必须返回一个类型实例)。

cgutils.create_struct_proxy()interval._getvalue() 由于 Numba 传递值的方式,存在一些样板代码。值作为 llvmlite.ir.Value 的实例传递,这可能过于有限:特别是 LLVM 结构值非常底层。结构代理是围绕 LLVM 结构值的临时包装器,允许轻松获取或设置结构的成员。_getvalue() 调用只是从包装器中获取 LLVM 值。

装箱和拆箱

如果你在这个时候尝试使用一个 Interval 实例,你肯定会遇到错误 “无法将 Interval 转换为原生值”。这是因为 Numba 还不知道如何从 Python 的 Interval 实例创建一个原生区间值。让我们教它如何做到这一点:

from numba.extending import unbox, NativeValue
from contextlib import ExitStack

@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert a Interval object to a native interval structure.
    """
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    with ExitStack() as stack:
        lo_obj = c.pyapi.object_getattr_string(obj, "lo")
        with cgutils.early_exit_if_null(c.builder, stack, lo_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        lo_native = c.unbox(types.float64, lo_obj)
        c.pyapi.decref(lo_obj)
        with cgutils.early_exit_if(c.builder, stack, lo_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        hi_obj = c.pyapi.object_getattr_string(obj, "hi")
        with cgutils.early_exit_if_null(c.builder, stack, hi_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        hi_native = c.unbox(types.float64, hi_obj)
        c.pyapi.decref(hi_obj)
        with cgutils.early_exit_if(c.builder, stack, hi_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        interval.lo = lo_native.value
        interval.hi = hi_native.value

    return NativeValue(interval._getvalue(), is_error=c.builder.load(is_error_ptr))

Unbox 是“将 Python 对象转换为原生值”的另一个名称(它符合 Python 对象作为包含简单原生值的复杂盒子的概念)。该函数返回一个 NativeValue 对象,该对象允许其调用者访问计算出的原生值、错误位以及可能的其他信息。

上面的代码片段大量使用了 c.pyapi 对象,该对象提供了对 Python 解释器的 C API 的一个子集的访问。注意使用 early_exit_if_null 来检测和处理在解包对象时可能发生的任何错误(例如,尝试传递 Interval('a', 'b'))。

我们还希望执行反向操作,称为 装箱 ,以便从 Numba 函数返回区间值:

from numba.extending import box

@box(IntervalType)
def box_interval(typ, val, c):
    """
    Convert a native interval structure to an Interval object.
    """
    ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
    fail_obj = c.pyapi.get_null_object()

    with ExitStack() as stack:
        interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
        lo_obj = c.box(types.float64, interval.lo)
        with cgutils.early_exit_if_null(c.builder, stack, lo_obj):
            c.builder.store(fail_obj, ret_ptr)

        hi_obj = c.box(types.float64, interval.hi)
        with cgutils.early_exit_if_null(c.builder, stack, hi_obj):
            c.pyapi.decref(lo_obj)
            c.builder.store(fail_obj, ret_ptr)

        class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
        with cgutils.early_exit_if_null(c.builder, stack, class_obj):
            c.pyapi.decref(lo_obj)
            c.pyapi.decref(hi_obj)
            c.builder.store(fail_obj, ret_ptr)

        # NOTE: The result of this call is not checked as the clean up
        # has to occur regardless of whether it is successful. If it
        # fails `res` is set to NULL and a Python exception is set.
        res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
        c.pyapi.decref(lo_obj)
        c.pyapi.decref(hi_obj)
        c.pyapi.decref(class_obj)
        c.builder.store(res, ret_ptr)

    return c.builder.load(ret_ptr)

使用它

nopython 模式 函数现在能够使用 Interval 对象以及你在它们上定义的各种操作。你可以尝试例如以下函数:

from numba import njit

@njit
def inside_interval(interval, x):
    return interval.lo <= x < interval.hi

@njit
def interval_width(interval):
    return interval.width

@njit
def sum_intervals(i, j):
    return Interval(i.lo + j.lo, i.hi + j.hi)

结论

我们已经展示了如何完成以下任务:

  • 通过子类化 Type 类来定义一个新的 Numba 类型类

  • 为非参数类型定义一个Numba的单例类型实例

  • 使用 typeof_impl.register 教 Numba 如何推断某个类的 Python 值的 Numba 类型。

  • 教 Numba 如何推断 Python 类型本身的 Numba 类型,使用 as_numba_type.register

  • 使用 StructModelregister_model 定义 Numba 类型的数据模型

  • 使用 @box 装饰器为 Numba 类型实现一个装箱函数

  • 使用 @unbox 装饰器和 NativeValue 类为 Numba 类型实现一个拆箱函数

  • 使用 @type_callable@lower_builtin 装饰器类型化并实现一个可调用对象

  • 使用 make_attribute_wrapper 便利函数公开一个只读的结构属性

  • 使用 @overload_attribute 装饰器实现一个只读属性