jax.numpy.polyval

目录

jax.numpy.polyval#

jax.numpy.polyval(p, x, *, unroll=16)[源代码][源代码]#

在特定值处计算多项式。

JAX 实现的 numpy.polyval()

对于长度为 M 的一维多项式系数 p ,该函数返回值:

\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]
参数:
  • p (ArrayLike) – 多项式系数的数组,形状为 (M,)

  • x (ArrayLike) – 一个数字或一组数字。

  • unroll (int) – 用于控制 lax.scan 展开步数的数字。它必须静态指定。

返回:

x 形状相同的数组。

返回类型:

Array

备注

unroll 参数是 JAX 特有的。它不影响正确性,但在评估高阶多项式时可能对性能产生重大影响。该参数控制 jnp.polyval 实现中 lax.scan 的展开步数。考虑设置 ``unroll=128``(甚至更高)以提高加速器上的运行时性能,但会增加编译时间。

参见

示例

>>> p = jnp.array([2, 5, 1])
>>> jnp.polyval(p, 3)
Array(34., dtype=float32)

如果 x 是一个二维数组,polyval 返回一个与 x 形状相同的二维数组:

>>> x = jnp.array([[2, 1, 5],
...                [3, 4, 7],
...                [1, 3, 5]])
>>> jnp.polyval(p, x)
Array([[ 19.,   8.,  76.],
       [ 34.,  53., 134.],
       [  8.,  34.,  76.]], dtype=float32)