jax.numpy.put#
- jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[源代码][源代码]#
将元素放入指定索引的数组中。
JAX 实现的
numpy.put()
。The semantics of
numpy.put()
are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds theinplace
parameter which must be set to False` by the user as a reminder of this API difference.- 参数:
- 返回:
更新了指定条目的
a
副本。- 返回类型:
参见
jax.numpy.place()
: 通过布尔掩码将元素放入数组中。jax.numpy.ndarray.at()
: 使用NumPy风格索引的数组更新。jax.numpy.take()
: 从数组中提取给定索引处的值。
示例
>>> x = jnp.zeros(5, dtype=int) >>> indices = jnp.array([0, 2, 4]) >>> values = jnp.array([10, 20, 30]) >>> jnp.put(x, indices, values, inplace=False) Array([10, 0, 20, 0, 30], dtype=int32)
这等同于以下
jax.numpy.ndarray.at
索引语法:>>> x.at[indices].set(values) Array([10, 0, 20, 0, 30], dtype=int32)
处理越界索引有两种模式。默认情况下,它们会被裁剪:
>>> indices = jnp.array([0, 2, 6]) >>> jnp.put(x, indices, values, inplace=False, mode='clip') Array([10, 0, 20, 0, 30], dtype=int32)
或者,它们可以被包裹到数组的开始:
>>> jnp.put(x, indices, values, inplace=False, mode='wrap') Array([10, 30, 20, 0, 0], dtype=int32)
对于N维输入,索引指的是展平后的数组:
>>> x = jnp.zeros((3, 5), dtype=int) >>> indices = jnp.array([0, 7, 14]) >>> jnp.put(x, indices, values, inplace=False) Array([[10, 0, 0, 0, 0], [ 0, 0, 20, 0, 0], [ 0, 0, 0, 0, 30]], dtype=int32)