jax.numpy.roots

目录

jax.numpy.roots#

jax.numpy.roots(p, *, strip_zeros=True)[源代码][源代码]#

返回给定系数 p 的多项式的根。

JAX 实现的 numpy.roots()

参数:
  • p (ArrayLike) – 具有秩-1的多项式系数数组。

  • strip_zeros (bool) – bool, 默认=True。如果为True,则系数中的前导零将被去除,类似于 numpy.roots()。如果设置为False,前导零将不会被去除,并且未定义的根将在函数输出中表示为NaN值。strip_zeros 必须设置为 False 以使函数与 jax.jit() 和其他JAX变换兼容。

返回:

包含多项式根的数组。

返回类型:

Array

备注

与该函数的 np.roots 不同,jnp.roots 返回的根无论根的值如何,都在一个复数数组中。

参见

示例

>>> coeffs = jnp.array([0, 1, 2])

默认行为与 numpy 匹配,并去除前导零:

>>> jnp.roots(coeffs)
Array([-2.+0.j], dtype=complex64)

使用 strip_zeros=False 时,额外的根被设置为 NaN:

>>> jnp.roots(coeffs, strip_zeros=False)
Array([-2. +0.j, nan+nanj], dtype=complex64)