# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from functools import partial
import operator
import numpy as np
from jax import jit
from jax import lax
from jax._src import dtypes
from jax._src import core
from jax._src.numpy.lax_numpy import (
arange, argmin, array, asarray, atleast_1d, concatenate, convolve,
diag, dot, finfo, full, ones, outer, roll, trim_zeros,
trim_zeros_tol, vander, zeros)
from jax._src.numpy.ufuncs import maximum, true_divide, sqrt
from jax._src.numpy.reductions import all
from jax._src.numpy import linalg
from jax._src.numpy.util import (
check_arraylike, promote_dtypes, promote_dtypes_inexact, _where)
from jax._src.typing import Array, ArrayLike
@jit
def _roots_no_zeros(p: Array) -> Array:
# build companion matrix and find its eigenvalues (the roots)
if p.size < 2:
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
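# The companion matrix of the monic polynomial p / p[0] has ones on its first
# subdiagonal and -p[1:] / p[0] in its first row; its eigenvalues are exactly
# the roots of p.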
A = diag(ones((p.size - 2,), p.dtype), -1)
A = A.at[0, :].set(-p[1:] / p[0])
return linalg.eigvals(A)
@jit
def _roots_with_zeros(p: Array, num_leading_zeros: Array | int) -> Array:
# Avoid lapack errors when p is all zero
p = _where(len(p) == num_leading_zeros, 1.0, p)
# Roll any leading zeros to the end & compute the roots
roots = _roots_no_zeros(roll(p, -num_leading_zeros))
# Sort zero roots to the end.
roots = lax.sort_key_val(roots == 0, roots)[1]
# Set roots associated with num_leading_zeros to NaN
return _where(arange(roots.size) < roots.size - num_leading_zeros, roots, complex(np.nan, np.nan))
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
r"""Returns the roots of a polynomial given the coefficients ``p``.
JAX implementation of :func:`numpy.roots`.
Args:
p: Rank-1 array of polynomial coefficients.
strip_zeros: bool, default=True. If ``True``, leading zeros in the
coefficients will be stripped, similar to :func:`numpy.roots`. If set to
False, leading zeros will not be stripped, and undefined roots will be
represented by NaN values in the function output. ``strip_zeros`` must be
set to ``False`` for the function to be compatible with :func:`jax.jit` and
other JAX transformations.
Returns:
An array containing the roots of the polynomial.
Note:
Unlike :func:`numpy.roots`, ``jnp.roots`` always returns the roots in a
complex array, regardless of the values of the roots.
See Also:
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
sequence of roots.
- :func:`jax.numpy.polyfit`: Least squares polynomial fit to data.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
Examples:
>>> coeffs = jnp.array([0, 1, 2])
The default behavior matches numpy and strips leading zeros:
>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)
With ``strip_zeros=False``, extra roots are set to NaN:
>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)
"""
check_arraylike("roots", p)
p_arr = atleast_1d(promote_dtypes_inexact(p)[0])
del p
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p_arr.size < 2:
return array([], dtype=dtypes.to_complex_dtype(p_arr.dtype))
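# argmin(p_arr == 0) gives the index of the first nonzero coefficient (the
# first False); the all() guard handles an all-zero input, for which argmin
# alone would return 0.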
num_leading_zeros = _where(all(p_arr == 0), len(p_arr), argmin(p_arr == 0))
if strip_zeros:
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will result in some returned roots being set to NaN.")
return _roots_no_zeros(p_arr[num_leading_zeros:])
else:
return _roots_with_zeros(p_arr, num_leading_zeros)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x: ArrayLike, y: ArrayLike, deg: int, rcond: float | None = None,
full: bool = False, w: ArrayLike | None = None, cov: bool = False
) -> Array | tuple[Array, ...]:
r"""Least squares polynomial fit to data.
JAX implementation of :func:`numpy.polyfit`.
Given a set of data points ``(x, y)`` and degree of polynomial ``deg``, the
function finds a polynomial equation of the form:
.. math::
y = p(x) = p[0] x^{deg} + p[1] x^{deg - 1} + ... + p[deg]
Args:
x: Array of data points of shape ``(M,)``.
y: Array of data points of shape ``(M,)`` or ``(M, K)``.
deg: Degree of the polynomials. It must be specified statically.
rcond: Relative condition number of the fit. Default value is ``len(x) * eps``.
It must be specified statically.
full: Switch that controls the return value. Default is ``False``, which
restricts the return value to the array of polynomial coefficients ``p``.
If ``True``, the function returns a tuple ``(p, resids, rank, s, rcond)``.
It must be specified statically.
w: Array of weights of shape ``(M,)``. If None, all data points are considered
to have equal weight. If not None, the weight :math:`w_i` is applied to the
unsquared residual of :math:`y_i - \widehat{y}_i` at :math:`x_i`, where
:math:`\widehat{y}_i` is the fitted value of :math:`y_i`. Default is None.
cov: Boolean or string. If ``True``, returns the covariance matrix scaled
by ``resids/(M-deg-1)`` along with the polynomial coefficients. If
``cov='unscaled'``, returns the unscaled version of the covariance matrix.
Default is ``False``. ``cov`` is ignored if ``full=True``. It must be
specified statically.
Returns:
- An array of polynomial coefficients ``p`` if ``full=False`` and ``cov=False``.
- A tuple of arrays ``(p, resids, rank, s, rcond)`` if ``full=True``, where
- ``p`` is an array of shape ``(deg + 1,)`` or ``(deg + 1, K)`` containing the
polynomial coefficients.
- ``resids`` is the sum of squared residuals of shape ``()`` or ``(K,)``.
- ``rank`` is the rank of the Vandermonde coefficient matrix built from ``x``.
- ``s`` contains the singular values of the Vandermonde coefficient matrix.
- ``rcond`` is the value of ``rcond`` used for the fit, returned as an array.
- A tuple of arrays ``(p, C)`` if ``full=False`` and ``cov=True``, where
- ``p`` is an array of shape ``(deg + 1,)`` or ``(deg + 1, K)`` containing the
polynomial coefficients.
- ``C`` is the covariance matrix of the polynomial coefficients of shape
``(deg + 1, deg + 1)`` or ``(deg + 1, deg + 1, 1)``.
Note:
Unlike :func:`numpy.polyfit`, :func:`jax.numpy.polyfit` will not warn on
rank reduction, which indicates an ill-conditioned matrix.
See Also:
- :func:`jax.numpy.poly`: Finds the polynomial coefficients of the given
sequence of roots.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Examples:
>>> x = jnp.array([3., 6., 9., 4.])
>>> y = jnp.array([[0, 1, 2],
... [2, 5, 7],
... [8, 4, 9],
... [1, 6, 3]])
>>> p = jnp.polyfit(x, y, 2)
>>> with jnp.printoptions(precision=2, suppress=True):
... print(p)
[[ 0.2 -0.35 -0.14]
[-1.17 4.47 2.96]
[ 1.95 -8.21 -5.93]]
If ``full=True``, returns a tuple of arrays as follows:
>>> p, resids, rank, s, rcond = jnp.polyfit(x, y, 2, full=True)
>>> with jnp.printoptions(precision=2, suppress=True):
... print("Polynomial Coefficients:", "\n", p, "\n",
... "Residuals:", resids, "\n",
... "Rank:", rank, "\n",
... "s:", s, "\n",
... "rcond:", rcond)
Polynomial Coefficients:
[[ 0.2 -0.35 -0.14]
[-1.17 4.47 2.96]
[ 1.95 -8.21 -5.93]]
Residuals: [0.37 5.94 0.61]
Rank: 3
s: [1.67 0.47 0.04]
rcond: 4.7683716e-07
If ``cov=True`` and ``full=False``, returns a tuple of arrays having
polynomial coefficients and covariance matrix.
>>> p, C = jnp.polyfit(x, y, 2, cov=True)
>>> p.shape, C.shape
((3, 3), (3, 3, 1))
"""
if w is None:
check_arraylike("polyfit", x, y)
else:
check_arraylike("polyfit", x, y, w)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
x_arr, y_arr = asarray(x), asarray(y)
del x, y
if deg < 0:
raise ValueError("expected deg >= 0")
if x_arr.ndim != 1:
raise TypeError("expected 1D vector for x")
if x_arr.size == 0:
raise TypeError("expected non-empty vector for x")
if y_arr.ndim < 1 or y_arr.ndim > 2:
raise TypeError("expected 1D or 2D array for y")
if x_arr.shape[0] != y_arr.shape[0]:
raise TypeError("expected x and y to have same length")
# set rcond
if rcond is None:
rcond = len(x_arr) * finfo(x_arr.dtype).eps
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
# set up least squares equation for powers of x
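# vander(x, order) builds the Vandermonde matrix whose columns are
# x**(order-1), ..., x**1, x**0, so the fit reduces to a linear least squares
# problem lhs @ c ~= rhs.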
lhs = vander(x_arr, order)
rhs = y_arr
# apply weighting
if w is not None:
w, = promote_dtypes_inexact(w)
w_arr = asarray(w)
if w_arr.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w_arr.shape[0] != y_arr.shape[0]:
raise TypeError("expected w and y to have the same length")
lhs *= w_arr[:, np.newaxis]
if rhs.ndim == 2:
rhs *= w_arr[:, np.newaxis]
else:
rhs *= w_arr
# scale lhs to improve condition number and solve
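# each column of the Vandermonde matrix is divided by its Euclidean norm; the
# solution is rescaled below so the returned coefficients apply to the
# original, unscaled basis.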
scale = sqrt((lhs*lhs).sum(axis=0))
lhs /= scale[np.newaxis,:]
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
c = (c.T/scale).T # broadcast scale coefficients
if full:
return c, resids, rank, s, asarray(rcond)
elif cov:
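# (lhs.T @ lhs)^-1, mapped back to the unscaled basis, gives the covariance of
# the coefficients up to a factor; unless cov='unscaled', it is scaled by the
# residual variance resids / (len(x) - order).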
Vbase = linalg.inv(dot(lhs.T, lhs))
Vbase /= outer(scale, scale)
if cov == "unscaled":
fac = 1
else:
if len(x_arr) <= order:
raise ValueError("the number of data points must exceed order "
"to scale the covariance matrix")
fac = resids / (len(x_arr) - order)
fac = fac[0]  # convert the array of shape (1,) to a scalar
if y_arr.ndim == 1:
return c, Vbase * fac
else:
return c, Vbase[:, :, np.newaxis] * fac
else:
return c
@jit
def poly(seq_of_zeros: ArrayLike) -> Array:
r"""Returns the coefficients of a polynomial for the given sequence of roots.
JAX implementation of :func:`numpy.poly`.
Args:
seq_of_zeros: A scalar or an array of roots of the polynomial of shape ``(M,)``
or ``(M, M)``.
Returns:
An array containing the coefficients of the polynomial. The dtype of the
output is always promoted to inexact.
Note:
:func:`jax.numpy.poly` differs from :func:`numpy.poly`:
- When the input is a scalar, ``np.poly`` raises a ``TypeError``, whereas
``jnp.poly`` treats scalars the same as length-1 arrays.
- For complex-valued or square-shaped inputs, ``jnp.poly`` always returns
complex coefficients, whereas ``np.poly`` may return real or complex
depending on their values.
See also:
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
- :func:`jax.numpy.polyval`: Evaluate a polynomial at specific values.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
Scalar inputs:
>>> jnp.poly(1)
Array([ 1., -1.], dtype=float32)
Input array with integer values:
>>> x = jnp.array([1, 2, 3])
>>> jnp.poly(x)
Array([ 1., -6., 11., -6.], dtype=float32)
Input array with complex conjugates:
>>> x = jnp.array([2, 1+2j, 1-2j])
>>> jnp.poly(x)
Array([ 1.+0.j, -4.+0.j, 9.+0.j, -10.+0.j], dtype=complex64)
Input as a square matrix with real-valued entries:
>>> x = jnp.array([[2, 1, 5],
... [3, 4, 7],
... [1, 3, 5]])
>>> jnp.round(jnp.poly(x))
Array([ 1.+0.j, -11.-0.j, 9.+0.j, -15.+0.j], dtype=complex64)
"""
check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros_arr = atleast_1d(seq_of_zeros)
del seq_of_zeros
sh = seq_of_zeros_arr.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from jax._src.numpy import linalg
seq_of_zeros_arr = linalg.eigvals(seq_of_zeros_arr)
if seq_of_zeros_arr.ndim != 1:
raise ValueError("input must be 1d or non-empty square 2d array.")
dt = seq_of_zeros_arr.dtype
if len(seq_of_zeros_arr) == 0:
return ones((), dtype=dt)
a = ones((1,), dtype=dt)
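# multiply out the linear factors (x - r) one root at a time; convolving
# coefficient sequences is equivalent to multiplying the polynomials.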
for k in range(len(seq_of_zeros_arr)):
a = convolve(a, array([1, -seq_of_zeros_arr[k]], dtype=dt), mode='full')
return a
@partial(jit, static_argnames=['unroll'])
def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array:
r"""Evaluates the polynomial at specific values.
JAX implementation of :func:`numpy.polyval`.
For the 1D-polynomial coefficients ``p`` of length ``M``, the function returns
the value:
.. math::
p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}
Args:
p: An array of polynomial coefficients of shape ``(M,)``.
x: A number or an array of numbers.
unroll: A number used to control the number of unrolled steps with
``lax.scan``. It must be specified statically.
Returns:
An array of the same shape as ``x``.
Note:
The ``unroll`` parameter is JAX specific. It does not affect correctness but
can have a major impact on performance for evaluating high-order polynomials.
The parameter controls the number of unrolled steps with ``lax.scan`` inside
the ``jnp.polyval`` implementation. Consider setting ``unroll=128`` (or even
higher) to improve runtime performance on accelerators, at the cost of
increased compilation time.
See also:
- :func:`jax.numpy.polyfit`: Least squares polynomial fit.
- :func:`jax.numpy.poly`: Finds the coefficients of a polynomial with given
roots.
- :func:`jax.numpy.roots`: Computes the roots of a polynomial for given
coefficients.
Example:
>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)
If ``x`` is a 2D array, ``polyval`` returns a 2D array with the same shape
as ``x``:
>>> x = jnp.array([[2, 1, 5],
... [3, 4, 7],
... [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19., 8., 76.],
[ 34., 53., 134.],
[ 8., 34., 76.]], dtype=float32)
"""
check_arraylike("polyval", p, x)
p_arr, x_arr = promote_dtypes_inexact(p, x)
del p, x
shape = lax.broadcast_shapes(p_arr.shape[1:], x_arr.shape)
y = lax.full_like(x_arr, 0, shape=shape, dtype=x_arr.dtype)
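# Horner's method via lax.scan: one step per coefficient, y <- y * x + p[k].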
y, _ = lax.scan(lambda y, p: (y * x_arr + p, None), y, p_arr, unroll=unroll)
return y
@jit
def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the sum of the two polynomials.
JAX implementation of :func:`numpy.polyadd`.
Args:
a1: Array of polynomial coefficients.
a2: Array of polynomial coefficients.
Returns:
An array containing the coefficients of the sum of input polynomials.
Note:
:func:`jax.numpy.polyadd` only accepts arrays as input unlike
:func:`numpy.polyadd` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polyadd(x1, x2)
Array([5, 6, 4], dtype=int32)
>>> x3 = jnp.array([[2, 3, 1]])
>>> x4 = jnp.array([[5, 7, 3],
... [8, 2, 6]])
>>> jnp.polyadd(x3, x4)
Array([[ 5, 7, 3],
[10, 5, 7]], dtype=int32)
>>> x5 = jnp.array([1, 3, 5])
>>> x6 = jnp.array([[5, 7, 9],
... [8, 6, 4]])
>>> jnp.polyadd(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
>>> x7 = jnp.array([2])
>>> jnp.polyadd(x6, x7)
Array([[ 5, 7, 9],
[10, 8, 6]], dtype=int32)
"""
check_arraylike("polyadd", a1, a2)
a1_arr, a2_arr = promote_dtypes(a1, a2)
del a1, a2
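# coefficients are stored highest-degree first, so the shorter array is added
# to the trailing (low-order) coefficients of the longer one.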
if a2_arr.shape[0] <= a1_arr.shape[0]:
return a1_arr.at[-a2_arr.shape[0]:].add(a2_arr)
else:
return a2_arr.at[-a1_arr.shape[0]:].add(a1_arr)
@partial(jit, static_argnames=('m',))
def polyint(p: ArrayLike, m: int = 1, k: int | ArrayLike | None = None) -> Array:
r"""Returns the coefficients of the integration of specified order of a polynomial.
JAX implementation of :func:`numpy.polyint`.
Args:
p: An array of polynomial coefficients.
m: Order of integration. Default is 1. It must be specified statically.
k: Scalar or array of ``m`` integration constant(s).
Returns:
An array of coefficients of the integrated polynomial.
See also:
- :func:`jax.numpy.polyder`: Computes the coefficients of the derivative of
a polynomial.
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
Examples:
The first order integration of the polynomial :math:`12 x^2 + 12 x + 6` is
:math:`4 x^3 + 6 x^2 + 6 x`.
>>> p = jnp.array([12, 12, 6])
>>> jnp.polyint(p)
Array([4., 6., 6., 0.], dtype=float32)
Since the constant ``k`` is not provided, the result includes ``0`` at the end.
If the constant ``k`` is provided:
>>> jnp.polyint(p, k=4)
Array([4., 6., 6., 4.], dtype=float32)
and the second order integration is :math:`x^4 + 2 x^3 + 3 x^2`:
>>> jnp.polyint(p, m=2)
Array([1., 2., 3., 0., 0.], dtype=float32)
When ``m>=2``, the constants ``k`` should be provided as an array having
``m`` elements. The second order integration of the polynomial
:math:`12 x^2 + 12 x + 6` with the constants ``k=[4, 5]`` is
:math:`x^4 + 2 x^3 + 3 x^2 + 4 x + 5`:
>>> jnp.polyint(p, m=2, k=jnp.array([4, 5]))
Array([1., 2., 3., 4., 5.], dtype=float32)
"""
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
check_arraylike("polyint", p, k)
p_arr, k_arr = promote_dtypes_inexact(p, k)
del p, k
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k_arr = atleast_1d(k_arr)
if len(k_arr) == 1:
k_arr = full((m,), k_arr[0])
if k_arr.shape != (m,):
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p_arr
else:
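# each entry of concatenate((p, k)) is divided by the product of the exponents
# it gains over the m integrations; maximum(1, ...) clips the factors so the
# trailing constants are handled correctly.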
grid = (arange(len(p_arr) + m, dtype=p_arr.dtype)[np.newaxis]
- arange(m, dtype=p_arr.dtype)[:, np.newaxis])
coeff = maximum(1, grid).prod(0)[::-1]
return true_divide(concatenate((p_arr, k_arr)), coeff)
@partial(jit, static_argnames=('m',))
def polyder(p: ArrayLike, m: int = 1) -> Array:
r"""Returns the coefficients of the derivative of specified order of a polynomial.
JAX implementation of :func:`numpy.polyder`.
Args:
p: Array of polynomial coefficients.
m: Order of differentiation (positive integer). Default is 1. It must be
specified statically.
Returns:
An array of polynomial coefficients representing the derivative.
Note:
:func:`jax.numpy.polyder` differs from :func:`numpy.polyder` when an integer
array is given. NumPy returns the result with dtype ``int`` whereas JAX
returns the result with dtype ``float``.
See also:
- :func:`jax.numpy.polyint`: Computes the integral of a polynomial.
- :func:`jax.numpy.polyval`: Evaluates a polynomial at specific values.
Examples:
The first order derivative of the polynomial :math:`2 x^3 - 5 x^2 + 3 x - 1`
is :math:`6 x^2 - 10 x + 3`:
>>> p = jnp.array([2, -5, 3, -1])
>>> jnp.polyder(p)
Array([ 6., -10., 3.], dtype=float32)
and its second order derivative is :math:`12 x - 10`:
>>> jnp.polyder(p, m=2)
Array([ 12., -10.], dtype=float32)
"""
check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p_arr, = promote_dtypes_inexact(p)
del p
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
return p_arr
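# the m-th derivative multiplies each surviving coefficient by the falling
# factorial of its exponent, d * (d-1) * ... * (d-m+1); coeff holds these
# products in the same (highest-degree-first) order as p_arr[:-m].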
coeff = (arange(m, len(p_arr), dtype=p_arr.dtype)[np.newaxis]
- arange(m, dtype=p_arr.dtype)[:, np.newaxis]).prod(0)
return p_arr[:-m] * coeff[::-1]
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
r"""Returns the product of two polynomials.
JAX implementation of :func:`numpy.polymul`.
Args:
a1: 1D array of polynomial coefficients.
a2: 1D array of polynomial coefficients.
trim_leading_zeros: Default is ``False``. If ``True``, leading zeros are
removed from the return value to match the result of NumPy, but this
prevents the function from being used in compiled code. Due to differences
in the accumulation of floating point arithmetic errors, the cutoff for
values to be considered zero may lead to inconsistent results between NumPy
and JAX, and even between different JAX backends, so the output shape may be
inconsistent when ``trim_leading_zeros=True``.
Returns:
An array of the coefficients of the product of the two polynomials. The dtype
of the output is always promoted to inexact.
Note:
:func:`jax.numpy.polymul` only accepts arrays as input unlike
:func:`numpy.polymul` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = np.array([2, 1, 0])
>>> x2 = np.array([0, 5, 0, 3])
>>> np.polymul(x1, x2)
array([10, 5, 6, 3, 0])
>>> jnp.polymul(x1, x2)
Array([ 0., 10., 5., 6., 3., 0.], dtype=float32)
If ``trim_leading_zeros=True``, the result matches that of ``np.polymul``:
>>> jnp.polymul(x1, x2, trim_leading_zeros=True)
Array([10., 5., 6., 3., 0.], dtype=float32)
For input arrays of dtype ``complex``:
>>> x3 = np.array([2., 1+2j, 1-2j])
>>> x4 = np.array([0, 5, 0, 3])
>>> np.polymul(x3, x4)
array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j])
>>> jnp.polymul(x3, x4)
Array([ 0. +0.j, 10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
>>> jnp.polymul(x3, x4, trim_leading_zeros=True)
Array([10. +0.j, 5.+10.j, 11.-10.j, 3. +6.j, 3. -6.j], dtype=complex64)
"""
check_arraylike("polymul", a1, a2)
a1_arr, a2_arr = promote_dtypes_inexact(a1, a2)
del a1, a2
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
if len(a1_arr) == 0:
a1_arr = asarray([0], dtype=a2_arr.dtype)
if len(a2_arr) == 0:
a2_arr = asarray([0], dtype=a1_arr.dtype)
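# the coefficients of the product polynomial are the full discrete convolution
# of the two coefficient sequences.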
return convolve(a1_arr, a2_arr, mode='full')
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> tuple[Array, Array]:
r"""Returns the quotient and remainder of polynomial division.
JAX implementation of :func:`numpy.polydiv`.
Args:
u: Array of dividend polynomial coefficients.
v: Array of divisor polynomial coefficients.
trim_leading_zeros: Default is ``False``. If ``True``, leading zeros are
removed from the return value to match the result of NumPy, but this
prevents the function from being used in compiled code. Due to differences
in the accumulation of floating point arithmetic errors, the cutoff for
values to be considered zero may lead to inconsistent results between NumPy
and JAX, and even between different JAX backends, so the output shape may be
inconsistent when ``trim_leading_zeros=True``.
Returns:
A tuple of quotient and remainder arrays. The dtype of the output is always
promoted to inexact.
Note:
:func:`jax.numpy.polydiv` only accepts arrays as input unlike
:func:`numpy.polydiv` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polysub`: Computes the difference of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
Example:
>>> x1 = jnp.array([5, 7, 9])
>>> x2 = jnp.array([4, 1])
>>> np.polydiv(x1, x2)
(array([1.25 , 1.4375]), array([7.5625]))
>>> jnp.polydiv(x1, x2)
(Array([1.25 , 1.4375], dtype=float32), Array([0. , 0. , 7.5625], dtype=float32))
If ``trim_leading_zeros=True``, the result matches that of ``np.polydiv``:
>>> jnp.polydiv(x1, x2, trim_leading_zeros=True)
(Array([1.25 , 1.4375], dtype=float32), Array([7.5625], dtype=float32))
"""
check_arraylike("polydiv", u, v)
u_arr, v_arr = promote_dtypes_inexact(u, v)
del u, v
m = len(u_arr) - 1
n = len(v_arr) - 1
scale = 1. / v_arr[0]
q: Array = zeros(max(m - n + 1, 1), dtype = u_arr.dtype) # force same dtype
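# synthetic long division: each step divides the leading coefficient of the
# running remainder u_arr by v_arr[0] to get the next quotient coefficient d,
# then subtracts d * v_arr from the remainder.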
for k in range(0, m-n+1):
d = scale * u_arr[k]
q = q.at[k].set(d)
u_arr = u_arr.at[k:k+n+1].add(-d*v_arr)
if trim_leading_zeros:
# use the square root of finfo(dtype).eps to approximate the absolute tolerance used in numpy
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
return q, u_arr
@jit
def polysub(a1: ArrayLike, a2: ArrayLike) -> Array:
r"""Returns the difference of two polynomials.
JAX implementation of :func:`numpy.polysub`.
Args:
a1: Array of minuend polynomial coefficients.
a2: Array of subtrahend polynomial coefficients.
Returns:
An array containing the coefficients of the difference of two polynomials.
Note:
:func:`jax.numpy.polysub` only accepts arrays as input unlike
:func:`numpy.polysub` which accepts scalar inputs as well.
See also:
- :func:`jax.numpy.polyadd`: Computes the sum of two polynomials.
- :func:`jax.numpy.polymul`: Computes the product of two polynomials.
- :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial
division.
Example:
>>> x1 = jnp.array([2, 3])
>>> x2 = jnp.array([5, 4, 1])
>>> jnp.polysub(x1, x2)
Array([-5, -2, 2], dtype=int32)
>>> x3 = jnp.array([[2, 3, 1]])
>>> x4 = jnp.array([[5, 7, 3],
... [8, 2, 6]])
>>> jnp.polysub(x3, x4)
Array([[-5, -7, -3],
[-6, 1, -5]], dtype=int32)
>>> x5 = jnp.array([1, 3, 5])
>>> x6 = jnp.array([[5, 7, 9],
... [8, 6, 4]])
>>> jnp.polysub(x5, x6) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(2, 3) shape=(2,)
>>> x7 = jnp.array([2])
>>> jnp.polysub(x6, x7)
Array([[5, 7, 9],
[6, 4, 2]], dtype=int32)
"""
check_arraylike("polysub", a1, a2)
a1, a2 = promote_dtypes(a1, a2)
return polyadd(a1, -a2)