jax.lax.dynamic_update_slice_in_dim#
- jax.lax.dynamic_update_slice_in_dim(operand, update, start_index, axis)[源代码][源代码]#
围绕
dynamic_update_slice()
的便捷包装器,用于在单个轴
中更新切片。- 参数:
- 返回:
更新后的数组
- 返回类型:
示例
>>> x = jnp.zeros(6) >>> y = jnp.ones(3) >>> dynamic_update_slice_in_dim(x, y, 2, axis=0) Array([0., 0., 1., 1., 1., 0.], dtype=float32)
如果更新切片太大以至于无法放入数组中,起始索引将被调整以使其适应:
>>> dynamic_update_slice_in_dim(x, y, 3, axis=0) Array([0., 0., 0., 1., 1., 1.], dtype=float32) >>> dynamic_update_slice_in_dim(x, y, 5, axis=0) Array([0., 0., 0., 1., 1., 1.], dtype=float32)
以下是一个二维切片更新的示例:
>>> x = jnp.zeros((4, 4)) >>> y = jnp.ones((2, 4)) >>> dynamic_update_slice_in_dim(x, y, 1, axis=0) Array([[0., 0., 0., 0.], [1., 1., 1., 1.], [1., 1., 1., 1.], [0., 0., 0., 0.]], dtype=float32)
请注意,
update
中额外轴的形状不需要与operand
的相关维度匹配:>>> y = jnp.ones((2, 3)) >>> dynamic_update_slice_in_dim(x, y, 1, axis=0) Array([[0., 0., 0., 0.], [1., 1., 1., 0.], [1., 1., 1., 0.], [0., 0., 0., 0.]], dtype=float32)