关于字面量类型的注释
备注
本文档描述了一项高级功能,旨在克服与类型相关的编译机制的一些限制。
某些功能需要在编译期间根据字面值进行专门化,以生成Numba成功编译所需的类型稳定代码。这可以通过将字面值通过类型系统传播来实现。Numba将内联字面值识别为 numba.types.Literal
。例如:
def foo(x):
a = 123
return bar(x, a)
Numba 会将 a
的类型推断为 Literal[int](123)
。随后,bar()
的定义可以专门化其实现,知道第二个参数是一个值为 123
的 int
。
Literal
类型
与 Literal
类型相关的类和方法。
指定为字面类型
要在计划进行JIT编译的代码中将值指定为 Literal
类型,请使用以下函数:
- numba.literally(obj)[源代码]
强制 Numba 将 obj 解释为字面值。
obj 必须是字面量或调用函数的参数,其中该参数必须绑定到字面量。字面量要求会沿调用栈向上传播。
这个函数被编译器拦截,以改变编译行为,将相应的函数参数包装为
Literal
。它在 nopython 模式(解释器和 objectmode)之外 没有效果。当前的实现通过两种方式检测字面参数:
通过编译器传递扫描
literally
的使用。literally
被重载以引发numba.errors.ForceLiteralArg
来通知调度器以不同的方式处理相应的参数。此模式用于支持间接使用(通过函数调用)。
此函数的执行语义等同于一个恒等函数。
查看 numba/tests/test_literal_dispatch.py 以获取示例。
代码示例
1 import numba
2
3 def power(x, n):
4 raise NotImplementedError
5
6 @numba.extending.overload(power)
7 def ov_power(x, n):
8 if isinstance(n, numba.types.Literal):
9 # only if `n` is a literal
10 if n.literal_value == 2:
11 # special case: square
12 print("square")
13 return lambda x, n: x * x
14 elif n.literal_value == 3:
15 # special case: cubic
16 print("cubic")
17 return lambda x, n: x * x * x
18 else:
19 # If `n` is not literal, request literal dispatch
20 return lambda x, n: numba.literally(n)
21
22 print("generic")
23 return lambda x, n: x ** n
24
25 @numba.njit
26 def test_power(x, n):
27 return power(x, n)
28
29 # should print "square" and "9"
30 print(test_power(3, 2))
31
32 # should print "cubic" and "27"
33 print(test_power(3, 3))
34
35 # should print "generic" and "81"
36 print(test_power(3, 4))
37
内部细节
在内部,编译器会引发一个 ForceLiteralArgs
异常,以通知调度器使用 Literal
类型包装指定的参数。
- class numba.errors.ForceLiteralArg(arg_indices, fold_arguments=None, loc=None)
一个伪异常,用于指示调度器按字面类型输入参数
- 属性:
- 请求的参数frozenset[int]
参数的请求位置。
- __init__(arg_indices, fold_arguments=None, loc=None)
- 参数:
- arg_indicesSequence[int]
参数的请求位置。
- fold_arguments: 可调用对象
一个函数
(tuple, dict) -> tuple
,它绑定并展平args
和kwargs
。- locnumba.ir.Loc 或 None
- __or__(other)
与 self.combine(other) 相同
- combine(other)
通过或运算请求的参数返回一个新实例。
扩展内部
@overload
扩展可以在实现体内部像在普通jit代码中一样使用 literally
。
通过使用以下内容,可以显式处理字面要求:
- class numba.extending.SentryLiteralArgs(literal_args)[源代码]
- 参数:
- literal_argsSequence[str]
字面参数的名称序列
示例
以下行:
>>> SentryLiteralArgs(literal_args).for_pysig(pysig).bind(*args, **kwargs)
等同于:
>>> sentry_literal_args(pysig, literal_args, args, kwargs)