jax.lax.dynamic_update_slice_in_dim

jax.lax.dynamic_update_slice_in_dim#

jax.lax.dynamic_update_slice_in_dim(operand, update, start_index, axis)[源代码][源代码]#

围绕 dynamic_update_slice() 的便捷包装器,用于在单个 中更新切片。

参数:
  • operand (Array | np.ndarray) – 一个要切片的数组。

  • update (ArrayLike) – 包含要写入 operand 的新值的数组。

  • start_index (ArrayLike) – 一个单一的标量索引

  • axis (int) – 更新轴。

返回:

更新后的数组

返回类型:

Array

示例

>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)

如果更新切片太大以至于无法放入数组中,起始索引将被调整以使其适应:

>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)

以下是一个二维切片更新的示例:

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 4))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)

请注意,update 中额外轴的形状不需要与 operand 的相关维度匹配:

>>> y = jnp.ones((2, 3))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)