jax.numpy.logaddexp#
- jax.numpy.logaddexp(x1, x2, /)[源代码][源代码]#
计算
log(exp(x1) + exp(x2))
以避免溢出。JAX implementation of
numpy.logaddexp
- 参数:
x1 (ArrayLike) – 输入数组
x2 (ArrayLike) – 输入数组
- 返回:
包含结果的数组。
- 返回类型:
示例:
>>> x1 = jnp.array([1, 2, 3]) >>> x2 = jnp.array([4, 5, 6]) >>> result1 = jnp.logaddexp(x1, x2) >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) >>> print(jnp.allclose(result1, result2)) True