jax.numpy.convolve

目录

jax.numpy.convolve#

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)[源代码][源代码]#

两个一维数组的卷积。

JAX 实现的 numpy.convolve()

一维数组的卷积定义为:

\[c_k = \sum_j a_{k - j} v_j\]
参数:
  • a (ArrayLike) – 卷积的左手输入。必须满足 a.ndim == 1

  • v (ArrayLike) – 卷积的右手输入。必须满足 v.ndim == 1

  • mode (str) – 控制输出的大小。可用的操作有: * "full": (默认) 输出输入的完整卷积。 * "same": 返回 "full" 输出的中心部分,其大小与 a 相同。 * "valid": 返回 "full" 输出中不依赖于数组边缘填充的部分。

  • precision (PrecisionLike) – 指定计算的精度。有关可用值的描述,请参阅 jax.lax.Precision

  • preferred_element_type (DTypeLike | None) – 一个数据类型,指示将结果累积并返回该数据类型的结果。默认是 None,这意味着输入类型的默认累积类型。

返回:

包含卷积结果的数组。

返回类型:

Array

参见

示例

几个一维卷积的例子:

>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

jax.numpy.convolve 默认返回使用边缘隐式零填充的完整卷积:

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)

指定 mode = 'same' 返回一个与第一个输入相同大小的居中卷积:

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)

指定 mode = 'valid' 仅返回两个数组完全重叠的部分:

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)

对于复数输入:

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)