jax.numpy.append

目录

jax.numpy.append#

jax.numpy.append(arr, values, axis=None)[源代码][源代码]#

返回一个新数组,该数组在原数组的末尾附加了值。

JAX 实现的 numpy.append()

参数:
  • arr (ArrayLike) – 原始数组。

  • values (ArrayLike) – 要附加到数组中的值。values 必须与 arr 具有相同的维数,并且所有维度必须匹配,除了指定的轴之外。

  • axis (int | None) – 要沿其附加值的轴。如果为 None(默认),则 arrvalues 将在附加之前被展平。

返回:

一个新数组,其值附加到 arr 之后。

返回类型:

Array

示例

>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.append(a, b)
Array([1, 2, 3, 4, 5, 6], dtype=int32)

沿特定轴追加:

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6]])
>>> jnp.append(a, b, axis=0)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

沿尾部轴追加:

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[7], [8]])
>>> jnp.append(a, b, axis=1)
Array([[1, 2, 3, 7],
       [4, 5, 6, 8]], dtype=int32)