jax.scipy.fft.dct

目录

jax.scipy.fft.dct#

jax.scipy.fft.dct(x, type=2, n=None, axis=-1, norm=None)[源代码][源代码]#

计算输入的离散余弦变换

JAX 实现的 scipy.fft.dct()

参数:
  • x (Array) – 数组

  • type (int) – 整数,默认值 = 2。目前仅支持类型 2。

  • n (int | None) – 整数,默认值为 x.shape[axis]。变换的长度。如果大于 x.shape[axis],输入将被零填充;如果小于,输入将被截断。

  • axis (int) – 整数,默认=-1。将沿此轴执行离散余弦变换。

  • norm (str | None) – 字符串。归一化模式:可以是 [None, "backward", "ortho"] 之一。默认是 None,相当于 "backward"

返回:

包含 x 的离散余弦变换的数组

返回类型:

Array

参见

示例

>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x))
[[-0.58 -0.33 -1.08]
 [-0.88 -1.01 -1.79]
 [-1.06 -2.43  1.24]]

n 小于 x.shape[axis]

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2))
[[-0.22 -0.9 ]
 [-0.57 -1.68]
 [-2.52 -0.11]]

n 小于 x.shape[axis]axis=0

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=2, axis=0))
[[-2.22  1.43 -0.67]
 [ 0.52 -0.26 -0.04]]

n 大于 x.shape[axis]axis=1

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dct(x, n=4, axis=1))
[[-0.58 -0.35 -0.64 -1.11]
 [-0.88 -0.9  -1.46 -1.68]
 [-1.06 -2.25 -1.15  1.93]]