jax.lax.动态切片#
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes)[源代码][源代码]#
封装了XLA的 DynamicSlice 操作符。
- 参数:
- 返回:
包含切片内容的数组。
- 返回类型:
示例
这是一个简单的二维动态切片:
>>> x = jnp.arange(12).reshape(3, 4) >>> x Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3)) Array([[ 5, 6, 7], [ 9, 10, 11]], dtype=int32)
注意对于请求的切片超出数组边界的情况,可能会出现令人惊讶的行为;在这种情况下,起始索引会调整以返回请求大小的切片:
>>> dynamic_slice(x, (1, 1), (2, 4)) Array([[ 4, 5, 6, 7], [ 8, 9, 10, 11]], dtype=int32)