jax.numpy.logaddexp

目录

jax.numpy.logaddexp#

jax.numpy.logaddexp(x1, x2, /)[源代码][源代码]#

计算 log(exp(x1) + exp(x2)) 以避免溢出。

JAX implementation of numpy.logaddexp

参数:
  • x1 (ArrayLike) – 输入数组

  • x2 (ArrayLike) – 输入数组

返回:

包含结果的数组。

返回类型:

Array

示例:

>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> result1 = jnp.logaddexp(x1, x2)
>>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2))
>>> print(jnp.allclose(result1, result2))
True