jax.Array.view

目录

jax.Array.view#

abstract Array.view(dtype=None, type=None)[源代码]#

返回数组的按位复制,视作新的数据类型。

这是一个围绕 jax.lax.bitcast_convert_type() 的更全面功能的包装器。

如果源数据类型和目标数据类型的位宽相同,结果将具有与输入数组相同的形状。如果目标数据类型的位宽与源数据类型不同,结果的最后一个轴的大小将相应调整。

>>> jnp.zeros([1,2,3], dtype=jnp.int16).view(jnp.int8).shape
(1, 2, 6)
>>> jnp.zeros([1,2,4], dtype=jnp.int8).view(jnp.int16).shape
(1, 2, 2)

涉及布尔类型的转换在所有情况下并未明确定义。关于上述结果的形状,布尔类型被视为具有8位的宽度。然而,当转换为布尔数组时,输入应仅包含0或1字节。否则,结果可能不可预测,或者可能会根据结果的使用方式而变化。

此转换是保证且安全的:

>>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_)
Array([ True, False,  True], dtype=bool)

然而,对于涉及此类视图的任何表达式的结果,没有任何保证,例如:jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)。特别是,结果可能会在JAX版本之间以及根据平台而变化。要安全地将此类数组转换为布尔数组,请将其与`0`进行比较:

>>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0
Array([ True,  True, False], dtype=bool)
参数:
  • self (Array)

  • dtype (DTypeLike | None)

  • type (None)

返回类型:

Array