jax.lax.bitcast_convert_type

jax.lax.bitcast_convert_type#

jax.lax.bitcast_convert_type(operand, new_dtype)[源代码][源代码]#

逐元素位转换。

封装了 XLA 的 BitcastConvertType 操作符,该操作符执行从一种类型到另一种类型的位转换。

输出形状取决于输入和输出数据类型的大小,遵循以下逻辑:

if new_dtype.itemsize == operand.dtype.itemsize:
  output_shape = operand.shape
if new_dtype.itemsize < operand.dtype.itemsize:
  output_shape = (*operand.shape, operand.dtype.itemsize // new_dtype.itemsize)
if new_dtype.itemsize > operand.dtype.itemsize:
  assert operand.shape[-1] * operand.dtype.itemsize == new_dtype.itemsize
  output_shape = operand.shape[:-1]
参数:
  • operand (ArrayLike) – 要转换的数组或标量值

  • new_dtype (DTypeLike) – 新类型。应为 NumPy 类型。

返回:

一个形状为 output_shape`(见上文)和类型为 `new_dtype 的数组,由与操作数相同的位构造。

返回类型:

Array