jax.numpy.repeat#
- jax.numpy.repeat(a, repeats, axis=None, *, total_repeat_length=None)[源代码][源代码]#
从重复元素构建数组。
JAX 实现的
numpy.repeat()
。- 参数:
a (ArrayLike) – N 维数组
repeats (ArrayLike) – 一维整数数组,指定重复次数。必须与重复轴的长度匹配。
axis (int | None) – 指定
a
的轴,沿此轴构建重复数组。如果为 None(默认),则首先将a
展平。total_repeat_length (int | None) – 对于
jnp.repeat
来说,必须静态指定以使其与jit()
和其他 JAX 变换兼容。如果sum(repeats)
大于指定的total_repeat_length
,剩余的值将被丢弃。如果sum(repeats)
小于total_repeat_length
,最终的值将被重复。
- 返回:
由
a
的重复值构造的数组。- 返回类型:
参见
jax.numpy.tile()
: 重复整个数组,而不是单个值。
示例
沿最后一个轴将每个值重复两次:
>>> a = jnp.array([[1, 2], ... [3, 4]]) >>> jnp.repeat(a, 2, axis=-1) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果未指定
axis
,输入数组将被展平:>>> jnp.repeat(a, 2) Array([1, 1, 2, 2, 3, 3, 4, 4], dtype=int32)
将一个数组传递给
repeats
以重复每个值不同的次数:>>> repeats = jnp.array([2, 3]) >>> jnp.repeat(a, repeats, axis=1) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
为了在
jit
和其他 JAX 变换中使用repeat
,必须使用total_repeat_length
静态指定输出的大小:>>> jit_repeat = jax.jit(jnp.repeat, static_argnames=['axis', 'total_repeat_length']) >>> jit_repeat(a, repeats, axis=1, total_repeat_length=5) Array([[1, 1, 2, 2, 2], [3, 3, 4, 4, 4]], dtype=int32)
如果 total_repeat_length 小于
sum(repeats)
,结果将被截断:>>> jit_repeat(a, repeats, axis=1, total_repeat_length=4) Array([[1, 1, 2, 2], [3, 3, 4, 4]], dtype=int32)
如果它更大,那么额外的条目将被填充为最终值:
>>> jit_repeat(a, repeats, axis=1, total_repeat_length=7) Array([[1, 1, 2, 2, 2, 2, 2], [3, 3, 4, 4, 4, 4, 4]], dtype=int32)