jax.numpy.polyval#
- jax.numpy.polyval(p, x, *, unroll=16)[源代码][源代码]#
在特定值处计算多项式。
JAX 实现的
numpy.polyval()
。对于长度为
M
的一维多项式系数p
,该函数返回值:\[p_0 x^{M - 1} + p_1 x^{M - 2} + ... + p_{M - 1}\]- 参数:
p (ArrayLike) – 多项式系数的数组,形状为
(M,)
。x (ArrayLike) – 一个数字或一组数字。
unroll (int) – 用于控制
lax.scan
展开步数的数字。它必须静态指定。
- 返回:
与
x
形状相同的数组。- 返回类型:
备注
unroll
参数是 JAX 特有的。它不影响正确性,但在评估高阶多项式时可能对性能产生重大影响。该参数控制jnp.polyval
实现中lax.scan
的展开步数。考虑设置 ``unroll=128``(甚至更高)以提高加速器上的运行时性能,但会增加编译时间。参见
jax.numpy.polyfit()
: 最小二乘多项式拟合。jax.numpy.poly()
: 找到具有给定根的多项式的系数。jax.numpy.roots()
: 计算给定系数的多项式的根。
示例
>>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32)
如果
x
是一个二维数组,polyval
返回一个与x
形状相同的二维数组:>>> x = jnp.array([[2, 1, 5], ... [3, 4, 7], ... [1, 3, 5]]) >>> jnp.polyval(p, x) Array([[ 19., 8., 76.], [ 34., 53., 134.], [ 8., 34., 76.]], dtype=float32)