jax.numpy.astype

目录

jax.numpy.astype#

jax.numpy.astype(x, dtype, /, *, copy=False, device=None)[源代码][源代码]#

将数组转换为指定的数据类型。

JAX 实现 numpy.astype()

这是通过 jax.lax.convert_element_type() 实现的,它在某些情况下可能与 numpy.astype() 的行为略有不同。特别是,浮点数到整数和整数到浮点数的转换细节是依赖于实现的。

参数:
  • x (ArrayLike) – 要转换的输入数组

  • dtype (DTypeLike | None) – 输出数据类型

  • copy (bool) – 如果为 True,则始终返回一个副本。如果为 False(默认),则仅在必要时返回一个副本。

  • device (xc.Device | Sharding | None) – 可选地指定输出将被提交到的设备。

返回:

一个与 x 形状相同的数组,包含指定 dtype 的值。

返回类型:

Array

参见

示例

>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int)  # truncates fractional values
Array([0, 0, 1], dtype=int32)