jax.numpy.expm1

目录

jax.numpy.expm1#

jax.numpy.expm1(x, /)[源代码][源代码]#

计算输入的每个元素的 exp(x)-1

JAX implementation of numpy.expm1.

参数:

x (ArrayLike) – 输入数组或标量。

返回:

包含 x 中每个元素的 exp(x)-1 的数组,会提升为不精确的数据类型。

返回类型:

Array

备注

jnp.expm1 对于小值 x 的计算精度远高于 exp(x)-1 的朴素计算。

参见

示例

>>> x = jnp.array([2, -4, 3, -1])
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.expm1(x))
[ 6.39 -0.98 19.09 -0.63]
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.exp(x)-1)
[ 6.39 -0.98 19.09 -0.63]

对于非常接近 0 的值,jnp.expm1(x)jnp.exp(x)-1 要准确得多:

>>> x1 = jnp.array([1e-4, 1e-6, 2e-10])
>>> jnp.expm1(x1)
Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32)
>>> jnp.exp(x1)-1
Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32)