jax.lax.slice#
- jax.lax.slice(operand, start_indices, limit_indices, strides=None)[源代码][源代码]#
封装了 XLA 的 Slice 操作符。
- 参数:
- 返回:
切片数组
- 返回类型:
示例
以下是一些简单的二维切片示例:
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> lax.slice(x, (1, 0), (3, 2)) Array([[4, 5], [8, 9]], dtype=int32)
>>> lax.slice(x, (0, 0), (3, 4), (1, 2)) Array([[ 0, 2], [ 4, 6], [ 8, 10]], dtype=int32)
这两个例子等同于以下 Python 切片语法:
>>> x[1:3, 0:2] Array([[4, 5], [8, 9]], dtype=int32)
>>> x[0:3, 0:4:2] Array([[ 0, 2], [ 4, 6], [ 8, 10]], dtype=int32)