jax.scipy.special.log_ndtr

目录

jax.scipy.special.log_ndtr#

jax.scipy.special.log_ndtr = <jax._src.custom_derivatives.custom_jvp object>[源代码][源代码]#

对数正态分布函数。

JAX implementation of scipy.special.log_ndtr.

有关正态分布函数的详细信息,请参见 ndtr

此函数通过调用 \(\log(\mathrm{ndtr}(x))\) 或使用渐近级数来计算 \(\log(\mathrm{ndtr}(x))\)。具体来说:

  • 对于 x > upper_segment,使用基于 \(\log(1-x) \approx -x, x \ll 1\) 的近似 -ndtr(-x)

  • 对于 lower_segment < x <= upper_segment,使用现有的 ndtr 技术并取对数。

  • 对于 x <= lower_segment,我们使用 erf 的级数近似来直接计算对数累积分布函数。

lower_segment 的设置基于输入的精度:

\[\begin{split}\begin{align} \mathit{lower\_segment} =& \begin{cases} -20 & x.\mathrm{dtype}=\mathit{float64} \\ -10 & x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \\ \mathit{upper\_segment} =& \begin{cases} 8& x.\mathrm{dtype}=\mathit{float64} \\ 5& x.\mathrm{dtype}=\mathit{float32} \\ \end{cases} \end{align}\end{split}\]

x < lower_segment 时,ndtr 的渐近级数近似为:

\[\begin{split}\begin{align} \mathrm{ndtr}(x) =&\ \mathit{scale} * (1 + \mathit{sum}) + R_N \\ \mathit{scale} =&\ \frac{e^{-0.5 x^2}}{-x \sqrt{2 \pi}} \\ \mathit{sum} =&\ \sum_{n=1}^N {-1}^n (2n-1)!! / (x^2)^n \\ R_N =&\ O(e^{-0.5 x^2} (2N+1)!! / |x|^{2N+3}) \end{align}\end{split}\]

其中 \((2n-1)!! = (2n-1) (2n-3) (2n-5) ... (3) (1)\) 是一个 双阶乘 运算符。

参数:
  • x (ArrayLike) – 类型为 float32float64 的数组。

  • series_order (int) – 正的Python整数。评估渐近展开的最大深度。这就是上面的 N

返回:

一个带有 dtype=x.dtype 的数组。

抛出:
  • TypeError – 如果 x.dtype 未被处理。

  • TypeError – 如果 series_order 不是一个 Python 整数

  • ValueError – 如果 series_order 不在 [0, 30] 范围内。

返回类型:

Array