jax.lax.slice_in_dim

jax.lax.slice_in_dim#

jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)[源代码][源代码]#

围绕 lax.slice() 的便捷包装器,仅应用于一个维度。

这实际上等同于 operand[..., start_index:limit_index:stride] ,其中索引应用于指定的轴。

参数:
  • operand (Array | np.ndarray) – 一个要切片的数组。

  • start_index (int | None) – 一个可选的起始索引(默认为零)

  • limit_index (int | None) – 一个可选的结束索引(默认为 operand.shape[axis])

  • stride (int) – 一个可选的步幅(默认为1)

  • axis (int) – 要应用切片操作的轴(默认为0)

返回:

包含切片内容的数组。

返回类型:

Array

示例

这是一个一维的例子:

>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)

以下是一些二维示例:

>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)