jax.numpy.right_shift

目录

jax.numpy.right_shift#

jax.numpy.right_shift(x1, x2, /)[源代码][源代码]#

x1 的位向右移动 x2 指定的数量。

JAX implementation of numpy.right_shift.

参数:
  • x1 (ArrayLike) – 输入数组,仅接受无符号整数子类型

  • x2 (ArrayLike) – 将 x1 中每个元素向右移动的位数,仅接受整数子类型

返回:

一个类似数组的对象,包含 x1 的元素按 x2 中指定的数量右移后的结果,形状与 x1x2 的广播形状相同。

返回类型:

Array

备注

如果 x1.shape != x2.shape,它们必须能够广播到一个共同的形状,这个共同的形状也将是输出的形状。将标量 x1 右移标量 x2 等同于 x1 // 2**x2

示例

>>> def print_binary(x):
...   return [bin(int(val)) for val in x]
>>> x1 = jnp.array([1, 2, 4, 8])
>>> print_binary(x1)
['0b1', '0b10', '0b100', '0b1000']
>>> x2 = 1
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([0, 1, 2, 4], dtype=int32)
>>> print_binary(result)
['0b0', '0b1', '0b10', '0b100']
>>> x1 = 16
>>> print_binary([x1])
['0b10000']
>>> x2 = jnp.array([1, 2, 3, 4])
>>> result = jnp.right_shift(x1, x2)
>>> result
Array([8, 4, 2, 1], dtype=int32)
>>> print_binary(result)
['0b1000', '0b100', '0b10', '0b1']