jax.lax.slice

目录

jax.lax.slice#

jax.lax.slice(operand, start_indices, limit_indices, strides=None)[源代码][源代码]#

封装了 XLA 的 Slice 操作符。

参数:
  • operand (ArrayLike) – 一个要切片的数组

  • start_indices (Sequence[int]) – 一系列 operand.ndim 的起始索引。

  • limit_indices (Sequence[int]) – 一系列 operand.ndim 限制索引。

  • strides (Sequence[int] | None) – 一个可选的 operand.ndim 步长序列。

返回:

切片数组

返回类型:

Array

示例

以下是一些简单的二维切片示例:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.slice(x, (1, 0), (3, 2))
Array([[4, 5],
       [8, 9]], dtype=int32)
>>> lax.slice(x, (0, 0), (3, 4), (1, 2))
Array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]], dtype=int32)

这两个例子等同于以下 Python 切片语法:

>>> x[1:3, 0:2]
Array([[4, 5],
       [8, 9]], dtype=int32)
>>> x[0:3, 0:4:2]
Array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]], dtype=int32)