# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import math
from jax import lax
import jax.numpy as jnp
from jax._src.util import canonicalize_axis
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
N_arr, k = promote_dtypes_complex(N, k)
return jnp.exp(-.5j * jnp.pi * k / N_arr)
def _dct_interleave(x: Array, axis: int) -> Array:
v0 = lax.slice_in_dim(x, None, None, 2, axis)
v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
return lax.concatenate([v0, v1], axis)
def _dct_ortho_norm(out: Array, axis: int) -> Array:
factor = lax.concatenate([lax.full((1,), 4, out.dtype), lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
return out / lax.sqrt(factor * out.shape[axis])
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
[文档]
def dct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
The default is ``None``, which is equivalent to ``"backward"``.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
Examples:
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x))
[[-0.58 -0.33 -1.08]
[-0.88 -1.01 -1.79]
[-1.06 -2.43 1.24]]
When ``n`` smaller than ``x.shape[axis]``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=2))
[[-0.22 -0.9 ]
[-0.57 -1.68]
[-2.52 -0.11]]
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=2, axis=0))
[[-2.22 1.43 -0.67]
[ 0.52 -0.26 -0.04]]
When ``n`` larger than ``x.shape[axis]`` and ``axis=1``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dct(x, n=4, axis=1))
[[-0.58 -0.35 -0.64 -1.11]
[-0.88 -0.9 -1.46 -1.68]
[-1.06 -2.25 -1.15 1.93]]
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['backward', 'ortho']:
raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")
axis = canonicalize_axis(axis, x.ndim)
if n is not None:
x = lax.pad(x, jnp.array(0, x.dtype),
[(0, n - x.shape[axis] if a == axis else 0, 0)
for a in range(x.ndim)])
N = x.shape[axis]
v = _dct_interleave(x, axis)
V = jnp.fft.fft(v, axis=axis)
k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
out = V * _W4(N, k)
out = 2 * out.real
if norm == 'ortho':
out = _dct_ortho_norm(out, axis)
return out
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
N1, N2 = x.shape[axis1], x.shape[axis2]
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
V = jnp.fft.fftn(v, axes=axes)
k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis1])
k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
[a for a in range(x.ndim) if a != axis2])
out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
out = 2 * out.real
if norm == 'ortho':
return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
return out
[文档]
def dctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.dctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
The default is ``None``, which is equivalent to ``"backward"``.
Returns:
array containing the discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
Examples:
``jax.scipy.fft.dctn`` computes the transform along both the axes by default
when ``axes`` argument is ``None``.
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dctn(x))
[[-5.04 -7.54 -3.26]
[ 0.83 3.64 -4.03]
[ 0.12 -0.73 3.74]]
When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
and dimension along ``axis 1`` will be same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dctn(x, s=[2]))
[[-2.92 -2.68 -5.74]
[ 0.42 0.97 1. ]]
When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
be ``2`` and dimension along ``axis 0`` will be same as that of input.
Also when ``axes=[1]``, transform will be computed only along ``axis 1``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
[[-0.22 -0.9 ]
[-0.57 -1.68]
[-2.52 -0.11]]
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.dctn(x, s=[2, 4]))
[[-2.92 -2.49 -4.21 -5.57]
[ 0.42 0.79 1.16 0.8 ]]
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['backward', 'ortho']:
raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")
if axes is None:
axes = range(x.ndim)
if len(axes) == 1:
return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
if s is not None:
ns = dict(zip(axes, s))
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
x = lax.pad(x, jnp.array(0, x.dtype), pads)
if len(axes) == 2:
return _dct2(x, axes=axes, norm=norm)
# compose high-D DCTs from 2D and 1D DCTs:
for axes_block in [axes[i:i+2] for i in range(0, len(axes), 2)]:
x = dctn(x, axes=axes_block, norm=norm)
return x
[文档]
def idct(x: Array, type: int = 2, n: int | None = None,
axis: int = -1, norm: str | None = None) -> Array:
"""Computes the inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idct`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
n: integer, default = x.shape[axis]. The length of the transform.
If larger than ``x.shape[axis]``, the input will be zero-padded, if
smaller, the input will be truncated.
axis: integer, default=-1. The axis along which the dct will be performed.
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
The default is ``None``, which is equivalent to ``"backward"``.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
Examples:
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x))
[[-0.02 -0. -0.17]
[-0.02 -0.07 -0.28]
[-0.16 -0.36 0.18]]
When ``n`` smaller than ``x.shape[axis]``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=2))
[[ 0. -0.19]
[-0.03 -0.34]
[-0.38 0.04]]
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=2, axis=0))
[[-0.35 0.23 -0.1 ]
[ 0.17 -0.09 0.01]]
When ``n`` larger than ``x.shape[axis]`` and ``axis=0``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idct(x, n=4, axis=0))
[[-0.34 0.03 0.07]
[ 0. 0.18 -0.17]
[ 0.14 0.09 -0.14]
[ 0. -0.18 0.14]]
``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
of ``jax.scipy.fft.dct``
>>> x_dct = jax.scipy.fft.dct(x)
>>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
Array(True, dtype=bool)
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['backward', 'ortho']:
raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")
axis = canonicalize_axis(axis, x.ndim)
if n is not None:
x = lax.pad(x, jnp.array(0, x.dtype),
[(0, n - x.shape[axis] if a == axis else 0, 0)
for a in range(x.ndim)])
N = x.shape[axis]
x = x.astype(jnp.float32)
if norm is None or norm == 'backward':
x = _dct_ortho_norm(x, axis)
x = _dct_ortho_norm(x, axis)
k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis])
# everything is complex from here...
w4 = _W4(N,k)
x = x.astype(w4.dtype)
x = x / (_W4(N, k))
x = x * 2 * N
x = jnp.fft.ifft(x, axis=axis)
# convert back to reals..
out = _dct_deinterleave(x.real, axis)
return out
[文档]
def idctn(x: Array, type: int = 2,
s: Sequence[int] | None=None,
axes: Sequence[int] | None = None,
norm: str | None = None) -> Array:
"""Computes the multidimensional inverse discrete cosine transform of the input
JAX implementation of :func:`scipy.fft.idctn`.
Args:
x: array
type: integer, default = 2. Currently only type 2 is supported.
s: integer or sequence of integers. Specifies the shape of the result. If not
specified, it will default to the shape of ``x`` along the specified ``axes``.
axes: integer or sequence of integers. Specifies the axes along which the
transform will be computed.
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
The default is ``None``, which is equivalent to ``"backward"``.
Returns:
array containing the inverse discrete cosine transform of x
See Also:
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
Examples:
``jax.scipy.fft.idctn`` computes the transform along both the axes by default
when ``axes`` argument is ``None``.
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x))
[[-0.03 -0.08 -0.08]
[ 0.05 0.12 -0.09]
[-0.02 -0.04 0.08]]
When ``s=[2]``, dimension of the transform along ``axis 0`` will be ``2``
and dimension along ``axis 1`` will be the same as that of input.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2]))
[[-0.01 -0.03 -0.14]
[ 0. 0.03 0.06]]
When ``s=[2]`` and ``axes=[1]``, dimension of the transform along ``axis 1`` will
be ``2`` and dimension along ``axis 0`` will be same as that of input.
Also when ``axes=[1]``, transform will be computed only along ``axis 1``.
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
[[ 0. -0.19]
[-0.03 -0.34]
[-0.38 0.04]]
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jax.scipy.fft.idctn(x, s=[2, 4]))
[[-0.01 -0.01 -0.05 -0.11]
[ 0. 0.01 0.03 0.04]]
``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
of ``jax.scipy.fft.dctn``
>>> x_dctn = jax.scipy.fft.dctn(x)
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
Array(True, dtype=bool)
"""
if type != 2:
raise NotImplementedError('Only DCT type 2 is implemented.')
if norm is not None and norm not in ['backward', 'ortho']:
raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")
if axes is None:
axes = range(x.ndim)
if len(axes) == 1:
return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
if s is not None:
ns = dict(zip(axes, s))
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
x = lax.pad(x, jnp.array(0, x.dtype), pads)
# compose high-D DCTs from 1D DCTs:
for axis in axes:
x = idct(x, axis=axis, norm=norm)
return x
def _dct_deinterleave(x: Array, axis: int) -> Array:
empty_slice = slice(None, None, None)
ix0 = tuple(
slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice
for i in range(len(x.shape)))
ix1 = tuple(
slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice
for i in range(len(x.shape)))
v0 = x[ix0]
v1 = lax.rev(x[ix1], (axis,))
out = jnp.zeros(x.shape, dtype=x.dtype)
evens = tuple(
slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
odds = tuple(
slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
out = out.at[evens].set(v0)
out = out.at[odds].set(v1)
return out