支持的 NumPy 功能

Numba 的一个目标是与 NumPy 无缝集成。NumPy 数组为同质数据集提供了一种高效的存储方法。NumPy dtypes 在编译时提供了有用的类型信息,而内存中潜在大量数据的常规、结构化存储为代码生成提供了理想的内存布局。Numba 擅长生成在 NumPy 数组之上执行的代码。

Numba 中的 NumPy 支持有多种形式:

  • Numba 理解对 NumPy ufuncs 的调用,并且能够为其中的许多函数生成等效的本机代码。

  • NumPy 数组在 Numba 中直接得到支持。由于在可能的情况下索引被降低为直接内存访问,因此对 NumPy 数组的访问非常高效。

  • Numba 能够生成 ufuncsgufuncs。这意味着可以在 Python 中实现 ufuncs 和 gufuncs,获得与使用 NumPy C API 在 C 扩展模块中实现的 ufuncs/gufuncs 相当的速度。

以下章节主要关注在 nopython 模式 中支持的 NumPy 功能,除非另有说明。

标量类型

Numba 支持以下 NumPy 标量类型:

  • 整数:所有有符号或无符号的整数,宽度可达64位

  • 布尔值

  • 实数: 单精度 (32位) 和双精度 (64位) 实数

  • 复数: 单精度 (2x32位) 和双精度 (2x64位) 复数

  • 日期时间与时间戳: 任意单位

  • **字符序列**(但无法对其进行操作)

  • 结构化标量: 由上述任何类型和上述类型的数组组成的结构化标量

以下标量类型和特性不受支持:

  • 任意 Python 对象

  • 半精度与扩展精度 实数和复数

  • 嵌套结构化标量 结构化标量的字段可能不包含其他结构化标量

NumPy 标量支持的操作几乎与 intfloat 等内置类型相同。你可以使用类型的构造函数从不同类型或宽度进行转换。此外,你可以使用 view(np.<dtype>) 方法对相同宽度的所有 intfloat 类型进行位转换。但是,你必须在 jitted 函数中使用 NumPy 构造函数来定义标量。例如,以下代码将正常工作:

>>> import numpy as np
>>> from numba import njit
>>> @njit
... def bitcast():
...     i = np.int64(-1)
...     print(i.view(np.uint64))
...
>>> bitcast()
18446744073709551615

而以下内容将不起作用:

>>> import numpy as np
>>> from numba import njit
>>> @njit
... def bitcast(i):
...     print(i.view(np.uint64))
...
>>> bitcast(np.int64(-1))
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
    ...
TypingError: Failed in nopython mode pipeline (step: ensure IR is legal prior to lowering)
'view' can only be called on NumPy dtypes, try wrapping the variable with 'np.<dtype>()'

File "<ipython-input-3-fc40aaab84c4>", line 3:
def bitcast(i):
    print(i.view(np.uint64))

结构化标量支持属性的获取和设置,以及使用常量字符串进行成员查找。存储在本地或全局元组中的字符串被视为常量字符串,并可用于成员查找。

import numpy as np
from numba import njit

arr = np.array([(1, 2)], dtype=[('a1', 'f8'), ('a2', 'f8')])
fields_gl = ('a1', 'a2')

@njit
def get_field_sum(rec):
    fields_lc = ('a1', 'a2')
    field_name1 = fields_lc[0]
    field_name2 = fields_gl[1]
    return rec[field_name1] + rec[field_name2]

get_field_sum(arr[0])  # returns 3

也可以将本地或全局元组与 literal_unroll 一起使用:

import numpy as np
from numba import njit, literal_unroll

arr = np.array([(1, 2)], dtype=[('a1', 'f8'), ('a2', 'f8')])
fields_gl = ('a1', 'a2')

@njit
def get_field_sum(rec):
    out = 0
    for f in literal_unroll(fields_gl):
        out += rec[f]
    return out

get_field_sum(arr[0])   # returns 3

记录子类型

警告

这是一个实验性功能。

Numba 允许 结构化标量的宽度子类型 。例如,dtype([('a', 'f8'), ('b', 'i8')]) 将被视为 dtype([('a', 'f8')]) 的子类型,因为后者是前者的严格子集,即字段 a 在两种类型中具有相同类型且处于相同位置。子类型关系在某些情况下会起作用,例如在不允许为特定输入编译的情况下,但输入是另一种允许类型的子类型。

import numpy as np
from numba import njit, typeof
from numba.core import types
record1 = np.array([1], dtype=[('a', 'f8')])[0]
record2 = np.array([(2,3)], dtype=[('a', 'f8'), ('b', 'f8')])[0]

@njit(types.float64(typeof(record1)))
def foo(rec):
    return rec['a']

foo(record1)
foo(record2)

如果没有子类型化,最后一行将会失败。通过子类型化,不会触发新的编译,但会使用为 record1 编译的函数来处理 record2

参见

NumPy 标量 参考。

数组类型

支持上述任何标量类型的 NumPy 数组 ,无论其形状或布局如何。

备注

NumPy MaskedArrays 不受支持。

数组访问

数组支持正常的迭代。完全支持基本的索引和切片,以及传递 None / np.newaxis 作为索引以增加结果维度。部分高级索引也得到支持:只允许一个高级索引,并且它必须是一维数组(它也可以与任意数量的基本索引组合)。

参见

NumPy 索引 参考。

结构化数组访问

Numba 目前支持通过属性访问结构化数组中单个元素的字段,以及通过获取和设置来访问。这稍微超出了 NumPy API 的范围,后者仅允许通过获取和设置来访问字段。例如:

from numba import njit
import numpy as np

record_type = np.dtype([("ival", np.int32), ("fval", np.float64)], align=True)

def f(rec):
    value = 2.5
    rec[0].ival = int(value)
    rec[0].fval = value
    return rec

arr = np.ones(1, dtype=record_type)

cfunc = njit(f)

# Works
print(cfunc(arr))

# Does not work
print(f(arr))

上述代码的输出结果为:

[(2, 2.5)]
Traceback (most recent call last):
  File "repro.py", line 22, in <module>
    print(f(arr))
  File "repro.py", line 9, in f
    rec[0].ival = int(value)
AttributeError: 'numpy.void' object has no attribute 'ival'

Numba编译版本的函数执行了,但纯Python版本因为不支持的属性访问而引发了一个错误。

备注

此行为最终将被弃用并移除。

属性

以下是支持的 NumPy 数组属性:

flags 对象

flags 属性返回的对象支持 contiguousc_contiguousf_contiguous 属性。

flat 对象

flat 属性返回的对象支持迭代和索引,但请注意:在非C连续数组上进行索引非常慢。

realimag 属性

无论数据类型如何,NumPy 都支持这些属性,但 Numba 选择限制它们的支持以避免潜在的用户错误。对于数值数据类型,Numba 遵循 NumPy 的行为。real 属性返回复数数组实部的视图,对于其他数值数据类型,它表现为恒等函数。imag 属性返回复数数组虚部的视图,对于其他数值数据类型,它返回一个形状和数据类型相同的全零数组。对于非数值数据类型,包括所有结构化/记录数据类型,使用这些属性将导致编译时(TypingError)错误。这种行为与 NumPy 不同,但这是为了避免与这些属性重叠的字段名可能引起的混淆。

计算

以下 NumPy 数组的方法在其基本形式中得到支持(没有任何可选参数):

相应的顶级 NumPy 函数(如 numpy.prod())同样得到支持。

其他方法

以下是 NumPy 数组支持的方法:

  • argmax() (支持 axis 关键字参数)。

  • argmin() (支持 axis 关键字参数)。

  • numpy.argpartition() (仅前两个参数)

  • argsort() (支持 kind 关键字参数,值可以是 'quicksort''mergesort'

  • astype() (仅限单参数形式)

  • copy() (无参数)

  • dot() (仅限单参数形式)

  • flatten() (无顺序参数;仅 ‘C’ 顺序)

  • item() (无参数)

  • itemset() (仅限单参数形式)

  • ptp() (不带参数)

  • ravel() (无顺序参数;仅’C’顺序)

  • repeat() (无轴参数)

  • reshape() (仅限单参数形式)

  • sort() (无参数)

  • sum() (可以带 axis 和/或 dtype 参数。)

    • axis 仅支持 integer 值。

    • 如果 axis 参数是一个编译时常量,所有有效值都支持。超出范围的值将在编译时导致 LoweringError

    • 如果 axis 参数不是编译时常量,则仅支持 0 到 3 的值。超出范围的值将导致运行时异常。

    • dtype 参数支持所有数值 dtypestimedelta 数组可以用作输入数组,但 timedelta 不支持作为 dtype 参数。

    • 当给定 dtype 时,它决定了内部累加器的类型。如果没有给定,则根据输入数组的 dtype 自动选择,主要遵循与 NumPy 相同的规则。然而,在 64 位 Windows 上,Numba 对整数输入使用 64 位累加器(int32 输入使用 int64uint32 输入使用 uint64),而在这些情况下,NumPy 会使用 32 位累加器。

  • transpose()

  • view() (仅限单参数形式)

  • __contains__()

在适用的情况下,相应的顶级 NumPy 函数(如 numpy.argmax())同样得到支持。

警告

排序可能比 NumPy 的实现稍微慢一些。

函数

线性代数

基本线性代数支持对浮点数和复数的1维和2维连续数组的运算:

备注

这些功能的实现需要安装 SciPy。

简化

以下归约函数受支持:

多项式

支持以下多项式类:* numpy.polynomial.polynomial.Polynomial (仅前三个参数)

支持以下多项式函数:* numpy.polynomial.polynomial.polyadd() * numpy.polynomial.polynomial.polydiv() * numpy.polynomial.polynomial.polyint() (仅前两个参数) * numpy.polynomial.polynomial.polymul() * numpy.polynomial.polynomial.polysub() * numpy.polynomial.polynomial.polyval() (参数 tensor 必须是布尔常量) * numpy.polynomial.polyutils.as_series() * numpy.polynomial.polyutils.trimseq()

其他功能

支持以下顶级函数:

以下构造函数均被支持,既可以接受数值输入(用于构造标量),也可以接受序列输入(用于构造数组):

  • numpy.bool_

  • numpy.complex64

  • numpy.complex128

  • numpy.float32

  • numpy.float64

  • numpy.int8

  • numpy.int16

  • numpy.int32

  • numpy.int64

  • numpy.intc

  • numpy.intp

  • numpy.uint8

  • numpy.uint16

  • numpy.uint32

  • numpy.uint64

  • numpy.uintc

  • numpy.uintp

以下机器参数类受到支持,所有属性均为纯数值:

字面数组

Python 和 Numba 都没有实际的数组字面量,但你可以通过调用 numpy.array() 在一个嵌套的元组上来构造任意数组:

a = numpy.array(((a, b, c), (d, e, f)))

(Numba 尚不支持嵌套列表)

模块

random

生成器对象

Numba 支持 numpy.random.Generator() 对象。从版本 0.56 开始,用户可以将单独的 NumPy Generator 对象传递到 Numba 函数中,并在函数内部使用它们的方法。随机数生成的算法与 NumPy 使用相同,因此在相同的参数下(同样适用 NumPy Generator 方法的相同文档注释),使用 NumPy 和 Numba 生成的随机数保持一致。当前 Numba 对 Generator 的支持不是线程安全的,因此我们不建议在具有并行执行逻辑的方法中使用 Generator 方法。

备注

NumPy 的 Generator 对象依赖于 BitGenerator 来管理状态并生成随机位,这些随机位随后被转换为有用的分布中的随机值。Numba 将 unbox Generator 对象,并使用 NumPy 的 ctypes 接口绑定来维护对底层 BitGenerator 对象的引用。因此,Generator 对象可以跨越 JIT 边界,并且其函数可以在 Numba-Jit 代码中使用。请注意,由于只维护了对 BitGenerator 对象的引用,任何在 Numba 代码外部对特定 Generator 对象状态的更改都会影响 Numba 代码内部的 Generator 状态。

x = np.random.default_rng(1)
y = np.random.default_rng(1)

size = 10

@numba.njit
def do_stuff(gen):
    return gen.random(size=int(size / 2))

original = x.random(size=size)
# [0.51182162 0.9504637  0.14415961 0.94864945 0.31183145
#  0.42332645 0.82770259 0.40919914 0.54959369 0.02755911]

numba_func_res = do_stuff(y)
# [0.51182162 0.9504637  0.14415961 0.94864945 0.31183145]

after_numba = y.random(size=int(size / 2))
# [0.42332645 0.82770259 0.40919914 0.54959369 0.02755911]

以下 生成器 方法受支持:

  • numpy.random.Generator().beta()

  • numpy.random.Generator().chisquare()

  • numpy.random.Generator().exponential()

  • numpy.random.Generator().f()

  • numpy.random.Generator().gamma()

  • numpy.random.Generator().geometric()

  • numpy.random.Generator().integers()lowhigh 都是必需参数。目前不支持低和高数组的值。)

  • numpy.random.Generator().laplace()

  • numpy.random.Generator().logistic()

  • numpy.random.Generator().lognormal()

  • numpy.random.Generator().logseries() (接受浮点值以及转换为浮点数的数据类型。目前不支持 p 的数组值。)

  • numpy.random.Generator().negative_binomial()

  • numpy.random.Generator().noncentral_chisquare() (接受浮点数值以及可以转换为浮点数的数据类型。目前不支持 dfnumnonc 的数组值。)

  • numpy.random.Generator().noncentral_f() (接受浮点数值以及可以转换为浮点数的数据类型。目前不支持 dfnumdfdennonc 的数组值。)

  • numpy.random.Generator().normal()

  • numpy.random.Generator().pareto()

  • numpy.random.Generator().permutation() (仅接受 NumPy ndarray 和整数。)

  • numpy.random.Generator().poisson()

  • numpy.random.Generator().power()

  • numpy.random.Generator().random()

  • numpy.random.Generator().rayleigh()

  • numpy.random.Generator().shuffle() (仅接受 NumPy ndarrays。)

  • numpy.random.Generator().standard_cauchy()

  • numpy.random.Generator().standard_exponential()

  • numpy.random.Generator().standard_gamma()

  • numpy.random.Generator().standard_normal()

  • numpy.random.Generator().standard_t()

  • numpy.random.Generator().triangular()

  • numpy.random.Generator().uniform()

  • numpy.random.Generator().wald()

  • numpy.random.Generator().weibull()

  • numpy.random.Generator().zipf()

备注

由于编译器之间的指令选择差异,在32位架构以及linux-aarch64和linux-ppc64le平台上,与NumPy相比,可能会出现数千个ULP的差异。对于Linux-x86_64、Windows-x86_64和macOS,这些差异不太明显(数十个ULP),但不保证遵循异常模式,在某些情况下可能会增加。

这些差异不太可能影响随机数生成的“质量”,因为它们是通过使用融合乘加而不是乘法后加法时发生的舍入变化引起的。

随机状态和传统随机数生成

Numba 支持 numpy.random 模块中的顶级函数,但不允许创建单独的 RandomState 实例。使用的算法与 标准随机模块 相同(因此适用相同的注释),但具有独立的内部状态:从一个生成器中播种或抽取数字不会影响另一个。

支持以下功能。

初始化

警告

从解释代码(包括从 对象模式 代码)调用 numpy.random.seed() 将种子化 NumPy 随机生成器,而不是 Numba 随机生成器。要种子化 Numba 随机生成器,请参见下面的示例。

from numba import njit
import numpy as np

@njit
def seed(a):
    np.random.seed(a)

@njit
def rand():
    return np.random.rand()


# Incorrect seeding
np.random.seed(1234)
print(rand())

np.random.seed(1234)
print(rand())

# Correct seeding
seed(1234)
print(rand())

seed(1234)
print(rand())

简单随机数据

排列

发行版

以下函数支持所有参数。

备注

从非Numba代码(或从 对象模式 代码)调用 numpy.random.seed() 将种子化NumPy随机生成器,而不是Numba随机生成器。

备注

自版本0.28.0起,生成器是线程安全和分叉安全的。每个线程和每个进程将生成独立的随机数流。

stride_tricks

以下是 numpy.lib.stride_tricks 模块中支持的函数:

  • as_strided()strides 参数是必需的,subok 参数不支持)

  • sliding_window_view()subok 参数不受支持,writeable 参数不受支持,返回的视图始终是可写的)

标准 ufuncs

Numba 的一个目标是让所有 NumPy 中的标准 ufuncs 都能被 Numba 理解。当在编译函数时发现支持的 ufunc,Numba 会将该 ufunc 映射到等效的本地代码。这使得那些 ufuncs 可以在以 nopython 模式 编译的 Numba 代码中使用。

限制

目前,只有一部分标准的 ufuncs 在 nopython 模式 下工作。以下是 Numba 所知的不同标准 ufuncs 的列表,按照与 NumPy 文档中相同的顺序排列。

数学运算

UFUNC

模式

名字

对象模式

nopython 模式

添加

减去

乘法

除法

logaddexp

logaddexp2

true_divide

floor_divide

negative

力量

float_power

余数

模块

fmod

divmod (*)

abs

绝对

fabs

打印

签名

conj

exp

exp2

日志

log2

log10

expm1

log1p

平方根

平方

cbrt

互惠

共轭

gcd

lcm

(*) 不支持在 timedelta 类型上使用

三角函数

UFUNC

模式

名字

对象模式

nopython 模式

sin

余弦

tan

反正弦

arccos

arctan

arctan2

hypot

sinh

cosh

tanh

反双曲正弦函数

arccosh

arctanh

deg2rad

rad2deg

弧度

位操作函数

UFUNC

模式

名字

对象模式

nopython 模式

按位与

按位或

按位异或

按位取反

反转

左移

右移

比较函数

UFUNC

模式

名字

对象模式

nopython 模式

更大

greater_equal

less

less_equal

不等于

等于

logical_and

logical_or

逻辑异或

logical_not

最大值

最小值

fmax

fmin

浮动函数

UFUNC

模式

名字

对象模式

nopython 模式

isfinite

isinf

isnan

signbit

copysign

nextafter

modf

ldexp

frexp

地板

ceil

截断

间距

日期时间函数

UFUNC

模式

名字

对象模式

nopython 模式