jax.numpy.put

目录

jax.numpy.put#

jax.numpy.put(a, ind, v, mode=None, *, inplace=True)[源代码][源代码]#

将元素放入指定索引的数组中。

JAX 实现的 numpy.put()

The semantics of numpy.put() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version returns a modified copy of the input, and adds the inplace parameter which must be set to False` by the user as a reminder of this API difference.

参数:
  • a (ArrayLike) – 将放置值的数组。

  • ind (ArrayLike) – 在展平数组上放置值的索引数组。

  • v (ArrayLike) – 要放入数组的值的数组。

  • mode (str | None) – 指定如何处理越界索引的字符串。支持的值: - "clip"``(默认):将越界索引剪切到最后一个索引。 - ``"wrap":将越界索引环绕到数组的开始。

  • inplace (bool) – 必须设置为 False 以表明输入未就地修改,而是返回了一个修改后的副本。

返回:

更新了指定条目的 a 副本。

返回类型:

Array

参见

示例

>>> x = jnp.zeros(5, dtype=int)
>>> indices = jnp.array([0, 2, 4])
>>> values = jnp.array([10, 20, 30])
>>> jnp.put(x, indices, values, inplace=False)
Array([10,  0, 20,  0, 30], dtype=int32)

这等同于以下 jax.numpy.ndarray.at 索引语法:

>>> x.at[indices].set(values)
Array([10,  0, 20,  0, 30], dtype=int32)

处理越界索引有两种模式。默认情况下,它们会被裁剪:

>>> indices = jnp.array([0, 2, 6])
>>> jnp.put(x, indices, values, inplace=False, mode='clip')
Array([10,  0, 20,  0, 30], dtype=int32)

或者,它们可以被包裹到数组的开始:

>>> jnp.put(x, indices, values, inplace=False, mode='wrap')
Array([10,  30, 20,  0, 0], dtype=int32)

对于N维输入,索引指的是展平后的数组:

>>> x = jnp.zeros((3, 5), dtype=int)
>>> indices = jnp.array([0, 7, 14])
>>> jnp.put(x, indices, values, inplace=False)
Array([[10,  0,  0,  0,  0],
       [ 0,  0, 20,  0,  0],
       [ 0,  0,  0,  0, 30]], dtype=int32)