NumPy 中的数据类型提升#

当混合两种不同的数据类型时,NumPy 必须确定操作结果的适当 dtype.这一步骤被称为 提升寻找公共 dtype.

在典型情况下,用户不需要担心提升细节,因为提升步骤通常确保结果将匹配或超过输入的精度.

例如,当输入具有相同的 dtype 时,结果的 dtype 与输入的 dtype 匹配:

>>> np.int8(1) + np.int8(1)
np.int8(2)

混合两种不同的数据类型通常会产生一个具有更高精度输入数据类型的结果:

>>> np.int8(4) + np.int64(8)  # 64 > 8
np.int64(12)
>>> np.float32(3) + np.float16(3)  # 32 > 16
np.float32(6.0)

在典型情况下,这不会导致意外.然而,如果你使用非默认的 dtypes,如无符号整数和低精度浮点数,或者如果你混合使用 NumPy 整数、NumPy 浮点和 Python 标量,NumPy 的提升规则的一些细节可能就相关了.请注意,这些详细的规则并不总是与其他语言的规则相匹配 [1].

数值数据类型有四种”种类”,并且有一个自然的层次结构.

  1. 无符号整数 (uint)

  2. 有符号整数 (int)

  3. 浮点数 (float)

  4. 复杂 (complex)

除了种类之外,NumPy 数值类型还具有相关的精度,以位为单位指定.种类和精度共同指定数据类型.例如,``uint8`` 是一个使用 8 位存储的无符号整数.

操作的结果总是与输入的类型相同或更高.此外,结果的精度总是大于或等于输入的精度.仅此一点,就可能导致一些出乎意料的例子:

  1. 当混合浮点数和整数时,整数的精度可能会迫使结果变为更高精度的浮点数.例如,涉及 int64float16 的操作结果是 float64.

  2. 当混合使用相同精度的无符号和有符号整数时,结果将具有比任一输入 更高 的精度.此外,如果其中一个是已经具有64位精度的,则没有更高精度的整数可用,例如涉及 int64uint64 的操作会得到 float64.

请参阅下面的 数值提升 部分和图像以获取详细信息.

Python 标量的详细行为#

自从 NumPy 2.0 [2] 以来,在我们的推广规则中一个重要的点是,尽管涉及两个 NumPy dtypes 的操作永远不会丢失精度,但涉及 NumPy dtype 和 Python 标量(intfloatcomplex)的操作 可能 会丢失精度.例如,可能很直观地认为 Python 整数和 NumPy 整数之间操作的结果应该是 NumPy 整数.然而,Python 整数具有任意精度,而所有 NumPy dtypes 具有固定精度,因此 Python 整数的任意精度无法保留.

更一般地,NumPy 考虑 Python 标量的”种类”,但在确定结果 dtype 时忽略它们的精度.这通常很方便.例如,当使用低精度 dtype 的数组时,通常希望与 Python 标量的简单操作保留 dtype.

>>> arr_float32 = np.array([1, 2.5, 2.1], dtype="float32")
>>> arr_float32 + 10.0  # undesirable to promote to float64
array([11. , 12.5, 12.1], dtype=float32)
>>> arr_int16 = np.array([3, 5, 7], dtype="int16")
>>> arr_int16 + 10  # undesirable to promote to int64
array([13, 15, 17], dtype=int16)

在这两种情况下,结果的精度由 NumPy 的 dtype 决定.因此,``arr_float32 + 3.0`` 的行为与 arr_float32 + np.float32(3.0) 相同,而 arr_int16 + 10 的行为与 arr_int16 + np.int16(10.) 相同.

作为另一个例子,当混合 NumPy 整数与 Python floatcomplex 时,结果总是具有类型 float64complex128:

>> np.int16(1) + 1.0 np.float64(2.0)

然而,这些规则在与低精度数据类型一起工作时也可能导致令人惊讶的行为.

首先,由于在操作执行之前,Python 值被转换为 NumPy 值,当结果看起来很明显时,操作可能会失败并出现错误.例如,``np.int8(1) + 1000`` 无法继续,因为 1000 超过了 int8 的最大值.当 Python 标量无法强制转换为 NumPy 数据类型时,会引发错误:

>>> np.int8(1) + 1000
Traceback (most recent call last):
  ...
OverflowError: Python integer 1000 out of bounds for int8
>>> np.int64(1) * 10**100
Traceback (most recent call last):
...
OverflowError: Python int too large to convert to C long
>>> np.float32(1) + 1e300
np.float32(inf)
... RuntimeWarning: overflow encountered in cast

其次,由于Python的浮点数或整数精度总是被忽略,低精度的NumPy标量将继续使用其较低的精度,除非显式转换为更高精度的NumPy数据类型或Python标量(例如通过``int()``、float()``或``scalar.item()).这种较低的精度可能对某些计算有害,或导致不正确的结果,特别是在整数溢出的情况下:

>>> np.int8(100) + 100  # the result exceeds the capacity of int8
np.int8(-56)
... RuntimeWarning: overflow encountered in scalar add

请注意,当标量发生溢出时,NumPy 会发出警告,但对于数组不会;例如,``np.array(100, dtype=”uint8”) + 100`` 将 不会 发出警告.

数值提升#

以下图像显示了数值提升规则,垂直轴表示种类,水平轴表示精度.

../_images/nep-0050-promotion-no-fonts.svg

具有更高种类的输入 dtype 决定了结果 dtype 的种类.结果 dtype 的精度尽可能低,但在图表中不会出现在任一输入 dtype 的左侧.

请注意以下特定的规则和观察:

  1. 当 Python floatcomplex 与 NumPy 整数交互时,结果将是 float64``complex128``(黄色边框).NumPy 布尔值也将被转换为默认整数.[#default-int] 当涉及 NumPy 浮点值时,这并不相关.

  2. 精度是这样划分的 float16 < int16 < uint16 因为大的 uint16 不适合 int16 ,而大的 int16 在存储为 float16 时会失去精度.然而,这种模式被打破了,因为 NumPy 总是认为 float64complex128 是任何整数值的可接受提升结果.

  3. 一个特殊情况是,NumPy 会将许多有符号和无符号整数的组合提升为 float64.这里使用更高类型是因为没有有符号整数类型足够精确以容纳 uint64.

一般推广规则的例外#

在 NumPy 中,提升(promotion)指的是特定函数对结果的处理,在某些情况下,这意味着 NumPy 可能会偏离 np.result_type 给出的结果.

sumprod 的行为#

``np.sum`` 和 ``np.prod``: 在求和整数值(或布尔值)时,总是返回默认的整数类型.这通常是一个 int64.这样做的原因是,否则整数求和非常可能溢出并给出令人困惑的结果.这条规则也适用于底层的 np.add.reducenp.multiply.reduce.

使用NumPy或Python整数标量的显著行为#

NumPy 提升指的是结果数据类型和操作精度,但操作有时会决定结果.除法总是返回浮点值,而比较总是返回布尔值.

这导致了可能看起来像是规则的”例外”:

  • NumPy 与 Python 整数或混合精度整数的比较总是返回正确的结果.输入永远不会以丢失精度的方式进行转换.

  • 类型之间无法提升的相等比较将被视为全部 ``False``(相等)或全部 ``True``(不相等).

  • np.sin 这样的单目数学函数总是返回浮点值,通过将其转换为 float64 来接受任何 Python 整数输入.

  • 除法总是返回浮点值,因此也允许任何 NumPy 整数与任何 Python 整数值之间的除法,通过将两者都转换为 float64.

原则上,这些异常中的一些可能对其他函数有意义.如果你认为这是这种情况,请提出一个问题.

推广非数值数据类型#

NumPy 扩展了对非数值类型的提升,尽管在许多情况下提升没有明确定义,并且简单地被拒绝.

以下规则适用:

  • NumPy 字节字符串 (np.bytes_) 可以转换为 Unicode 字符串 (np.str_).然而,将字节转换为 Unicode 会因非 ASCII 字符而失败.

  • 出于某些目的,NumPy 几乎可以将任何其他数据类型提升为字符串.这适用于数组创建或连接.

  • np.array() 这样的数组构造函数在没有可行提升时会使用 object 数据类型.

  • 结构化 dtypes 可以在它们的字段名称和顺序匹配时进行提升.在这种情况下,所有字段都会单独提升.

  • NumPy timedelta 在某些情况下可以与整数进行提升.

备注

这些规则中的一些有些令人惊讶,并且正在考虑在未来进行更改.然而,任何不兼容的更改都必须权衡破坏现有代码的风险.如果您对提升工作有特定的想法,请提出问题.

推广的 dtype 实例的详细信息#

上述讨论主要涉及了混合不同 DType 类时的行为.附加到数组的 dtype 实例可以携带额外的信息,如字节顺序、元数据、字符串长度或精确的结构化 dtype 布局.

虽然结构化数据类型的字符串长度或字段名称很重要,但 NumPy 认为字节顺序、元数据以及结构化数据类型的精确布局是存储细节.在提升过程中,NumPy 不考虑这些存储细节:* 字节顺序转换为本地字节顺序.* 附加到数据类型的元数据可能会也可能不会保留.* 结果结构化数据类型将被打包(但如果输入是,则对齐).

这种行为是为大多数程序设计的最佳行为,在这些程序中,存储细节与最终结果无关,并且使用错误的字节顺序可能会大大降低评估速度.