使用 @stencil
装饰器
模板是一种常见的计算模式,其中数组元素根据称为模板内核的某种固定模式进行更新。Numba 提供了 @stencil
装饰器,使用户可以轻松指定模板内核,然后 Numba 生成必要的循环代码,将该内核应用于某些输入数组。因此,stencil 装饰器允许更清晰、更简洁的代码,并且与 并行 jit 选项 结合使用时,可以通过并行化模板执行来实现更高的性能。
基本用法
@stencil
装饰器的使用示例:
from numba import stencil
@stencil
def kernel1(a):
return 0.25 * (a[0, 1] + a[1, 0] + a[0, -1] + a[-1, 0])
模板内核通过看起来像标准Python函数定义的方式指定,但在数组索引方面有不同的语义。尽管根据内核定义可能具有不同的类型,模板生成与输入数组大小和形状相同的输出数组。从概念上讲,模板内核为输出数组中的每个元素运行一次。模板内核的返回值是写入该特定元素的输出数组的值。
参数 a
表示应用核的输入数组。对该数组的索引是相对于当前正在处理的输出数组的元素进行的。例如,如果正在处理元素 (x, y)
,那么模板核中的 a[0, 0]
对应于输入数组中的 a[x + 0, y + 0]
。类似地,模板核中的 a[-1, 1]
对应于输入数组中的 a[x - 1, y + 1]
。
根据指定的内核,内核可能不适用于输出数组的边界,因为这可能会导致访问输入数组越界。stencil装饰器处理这种情况的方式取决于选择了哪种 func_or_mode。默认模式是将输出数组的边界元素设置为零。
要在输入数组上调用模板,请像调用常规函数一样调用模板,并将输入数组作为参数传递。例如,使用上面定义的内核:
>>> import numpy as np
>>> input_arr = np.arange(100).reshape((10, 10))
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
[60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
[70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
[80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]])
>>> output_arr = kernel1(input_arr)
array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 11., 12., 13., 14., 15., 16., 17., 18., 0.],
[ 0., 21., 22., 23., 24., 25., 26., 27., 28., 0.],
[ 0., 31., 32., 33., 34., 35., 36., 37., 38., 0.],
[ 0., 41., 42., 43., 44., 45., 46., 47., 48., 0.],
[ 0., 51., 52., 53., 54., 55., 56., 57., 58., 0.],
[ 0., 61., 62., 63., 64., 65., 66., 67., 68., 0.],
[ 0., 71., 72., 73., 74., 75., 76., 77., 78., 0.],
[ 0., 81., 82., 83., 84., 85., 86., 87., 88., 0.],
[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
>>> input_arr.dtype
dtype('int64')
>>> output_arr.dtype
dtype('float64')
注意,stencil装饰器已经确定指定stencil内核的输出类型为 float64
,因此创建了 float64
类型的输出数组,而输入数组的类型为 int64
。
模板参数
Stencil 内核定义可以接受任意数量的参数,但有以下规定。第一个参数必须是一个数组。输出数组的大小和形状将与第一个参数相同。其他参数可以是标量或数组。对于数组参数,这些数组在每个维度上必须至少与第一个参数(数组)一样大。所有此类输入数组参数的数组索引都是相对的。
内核形状推断和边界处理
在上面的例子和大多数情况下,模板内核中的数组索引将仅使用 Integer
字面量。在这种情况下,模板装饰器能够分析模板内核以确定其大小。在上面的例子中,模板装饰器确定内核的形状为 3 x 3
,因为索引 -1
到 1
被用于第一和第二维度。请注意,模板装饰器也能正确处理非对称和非方形的模板内核。
基于模板核的大小,模板装饰器能够计算输出数组中边框的大小。如果在将核应用于输入数组的某个元素时会导致索引越界,那么该元素属于输出数组的边框。在上面的例子中,每个维度中的点 -1
和 +1
被访问,因此输出数组在所有维度上都有一个大小为1的边框。
并行模式能够从简单表达式中推断出内核索引为常量(如果可能的话)。例如:
@njit(parallel=True)
def stencil_test(A):
c = 2
B = stencil(
lambda a, c: 0.3 * (a[-c+1] + a[0] + a[c-1]))(A, c)
return B
Stencil 装饰器选项
备注
未来可能会增强模板装饰器,以提供额外的边框处理机制。目前,仅实现了一种行为,"constant"``(详见下文的 ``func_or_mode
)。
neighborhood
有时,仅使用 Integer
字面量编写模板内核可能不太方便。例如,假设我们想要计算时间序列数据的后30天移动平均值。可以写成 (a[-29] + a[-28] + ... + a[-1] + a[0]) / 30
,但模板装饰器提供了使用 neighborhood
选项的更简洁形式:
@stencil(neighborhood = ((-29, 0),))
def kernel2(a):
cumul = 0
for i in range(-29, 1):
cumul += a[i]
return cumul / 30
neighborhood 选项是一个元组的元组。外部元组的长度等于输入数组的维数。内部元组的长度总是为二,因为内部元组的每个元素对应于在相应维度中使用的最小和最大索引偏移量。
如果用户指定了一个邻域,但内核访问了指定邻域之外的元素,行为是未定义的。
func_or_mode
可选的 func_or_mode
参数控制输出数组边界的处理方式。目前,仅支持一个值 "constant"
。在 constant
模式下,当卷积核访问输入数组有效范围之外的元素时,不会应用卷积核。在这种情况下,输出数组中的这些元素将被赋值为一个常数值,该值由 cval
参数指定。
cval
可选的 cval 参数默认为零,但可以设置为任何所需值,当 func_or_mode
参数设置为 constant
时,该值将用于输出数组的边界。在所有其他模式下,cval 参数将被忽略。cval 参数的类型必须与模板内核的返回类型匹配。如果用户希望输出数组由特定类型构建,则应确保模板内核返回该类型。
standard_indexing
默认情况下,stencil 内核中的所有数组访问都按照上述方式处理为相对索引。然而,有时可能需要将一个辅助数组(例如权重数组)传递给 stencil 内核,并让该数组使用标准的 Python 索引而不是相对索引。为此,stencil 装饰器选项 standard_indexing
的值是一个字符串集合,这些字符串的名称与 stencil 函数中要使用标准 Python 索引而不是相对索引访问的参数相匹配:
@stencil(standard_indexing=("b",))
def kernel3(a, b):
return a[-1] * b[0] + a[0] + b[1]
StencilFunc
stencil装饰器返回一个类型为``StencilFunc``的可调用对象。一个``StencilFunc``对象包含多个属性,但用户可能只对``neighborhood``属性感兴趣。如果将``neighborhood``选项传递给stencil装饰器,则提供的邻域将存储在此属性中。否则,在第一次执行或编译时,系统会如上所述计算邻域,然后将计算出的邻域存储到此属性中。用户可以检查此属性,以验证计算出的邻域是否正确。
Stencil 调用选项
在内部,stencil装饰器将指定的stencil内核转换为常规的Python函数。此函数将具有与stencil内核定义中指定的相同参数,但还将包括以下可选参数。
out
可选的 out
参数被添加到 Numba 生成的每个模板函数中。如果指定,out
参数告诉 Numba 用户正在提供他们自己的预分配数组,用于模板输出的存储。在这种情况下,模板函数将不会分配自己的输出数组。用户应确保模板内核的返回类型可以安全地转换为用户指定的输出数组的元素类型,遵循 NumPy ufunc 转换规则。
下面展示了一个使用示例:
>>> import numpy as np
>>> input_arr = np.arange(100).reshape((10, 10))
>>> output_arr = np.full(input_arr.shape, 0.0)
>>> kernel1(input_arr, out=output_arr)