jax.numpy.repeat

目录

jax.numpy.repeat#

jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[源代码][源代码]#

从重复元素构建数组。

JAX 实现的 numpy.repeat()

参数:
  • a (ArrayLike) – N 维数组

  • repeats (ArrayLike) – 一维整数数组,指定重复次数。必须与重复轴的长度匹配。

  • axis (int | None) – 指定 a 的轴,沿此轴构建重复数组。如果为 None(默认),则首先将 a 展平。

  • total_repeat_length (int | None) – 对于 jnp.repeat 来说,必须静态指定以使其与 jit() 和其他 JAX 变换兼容。如果 sum(repeats) 大于指定的 total_repeat_length,剩余的值将被丢弃。如果 sum(repeats) 小于 total_repeat_length,最终的值将被重复。

返回:

a 的重复值构造的数组。

返回类型:

Array

参见

示例

沿最后一个轴将每个值重复两次:

>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.repeat(a, 2, axis=-1)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果未指定 axis ,输入数组将被展平:

>>> jnp.repeat(a, 2)
Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)

将一个数组传递给 repeats 以重复每个值不同的次数:

>>> repeats = jnp.array([2, 3])
>>> jnp.repeat(a, repeats, axis=1)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

为了在 jit 和其他 JAX 变换中使用 repeat ,必须使用 total_repeat_length 静态指定输出的大小:

>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length'])
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=5)
Array([[1, 1, 2, 2, 2],
       [3, 3, 4, 4, 4]], dtype=int32)

如果 total_repeat_length 小于 sum(repeats),结果将被截断:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4)
Array([[1, 1, 2, 2],
       [3, 3, 4, 4]], dtype=int32)

如果它更大,那么额外的条目将被填充为最终值:

>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7)
Array([[1, 1, 2, 2, 2, 2, 2],
       [3, 3, 4, 4, 4, 4, 4]], dtype=int32)