jax.numpy.result_type

目录

jax.numpy.result_type#

jax.numpy.result_type(*args)[源代码][源代码]#

返回应用 NumPy 后的类型

LAX 后端实现的 numpy.result_type()

原始文档字符串如下。

类型提升规则应用于参数。

NumPy 中的类型提升与 C++ 等语言的规则类似,但有一些细微的差别。当同时使用标量和数组时,数组的类型优先,并且会考虑标量的实际值。

例如,计算 3*a,其中 a 是一个 32 位浮点数的数组,直观上应该得到一个 32 位浮点数的输出。如果 3 是一个 32 位整数,NumPy 规则表明它不能无损地转换为 32 位浮点数,因此结果类型应该是 64 位浮点数。通过检查常量 ‘3’ 的值,我们看到它适合一个 8 位整数,可以无损地转换为 32 位浮点数。

返回:

out – 结果类型。

返回类型:

dtype

参数:

args (Any)