jax.lax.动态切片

jax.lax.动态切片#

jax.lax.dynamic_slice(operand, start_indices, slice_sizes)[源代码][源代码]#

封装了XLA的 DynamicSlice 操作符。

参数:
  • operand (Array | np.ndarray) – 一个要切片的数组。

  • start_indices (Array | np.ndarray | Sequence[ArrayLike]) – 一个标量索引列表,每个维度一个。这些值可能是动态的。

  • slice_sizes (Shape) – 切片的大小。必须是一个非负整数的序列,其长度等于 ndim(operand)。在JIT编译的函数内部,只支持静态值(JIT内部的所有JAX数组必须具有静态已知的大小)。

返回:

包含切片内容的数组。

返回类型:

Array

示例

这是一个简单的二维动态切片:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)

注意对于请求的切片超出数组边界的情况,可能会出现令人惊讶的行为;在这种情况下,起始索引会调整以返回请求大小的切片:

>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)