jax.numpy.right_shift#
- jax.numpy.right_shift(x1, x2, /)[源代码][源代码]#
将
x1
的位向右移动x2
指定的数量。JAX implementation of
numpy.right_shift
.- 参数:
x1 (ArrayLike) – 输入数组,仅接受无符号整数子类型
x2 (ArrayLike) – 将
x1
中每个元素向右移动的位数,仅接受整数子类型
- 返回:
一个类似数组的对象,包含
x1
的元素按x2
中指定的数量右移后的结果,形状与x1
和x2
的广播形状相同。- 返回类型:
备注
如果
x1.shape != x2.shape
,它们必须能够广播到一个共同的形状,这个共同的形状也将是输出的形状。将标量 x1 右移标量 x2 等同于x1 // 2**x2
。示例
>>> def print_binary(x): ... return [bin(int(val)) for val in x]
>>> x1 = jnp.array([1, 2, 4, 8]) >>> print_binary(x1) ['0b1', '0b10', '0b100', '0b1000'] >>> x2 = 1 >>> result = jnp.right_shift(x1, x2) >>> result Array([0, 1, 2, 4], dtype=int32) >>> print_binary(result) ['0b0', '0b1', '0b10', '0b100']
>>> x1 = 16 >>> print_binary([x1]) ['0b10000'] >>> x2 = jnp.array([1, 2, 3, 4]) >>> result = jnp.right_shift(x1, x2) >>> result Array([8, 4, 2, 1], dtype=int32) >>> print_binary(result) ['0b1000', '0b100', '0b10', '0b1']