jax.numpy.var

目录

jax.numpy.var#

jax.numpy.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)[源代码][源代码]#

计算沿指定轴的方差。

JAX 实现的 numpy.var()

参数:
  • a (ArrayLike) – 输入数组。

  • axis (Axis) – 可选,整数或整数序列,默认=None。计算方差的轴。如果为None,则沿所有轴计算方差。

  • dtype (DTypeLike | None) – 输出数组的类型。默认=None。

  • ddof (int) – int, 默认=0。自由度。在方差计算中的除数是 N-ddofN 是给定轴上的元素数量。

  • keepdims (bool) – bool, 默认=False。如果为真,缩减的轴将保留在结果中,尺寸为1。

  • where (ArrayLike | None) – 可选,布尔数组,默认=None。用于计算方差的元素。数组应与输入广播兼容。

  • correction (int | float | None) – int 或 float, 默认=None. ddof 的替代名称。ddof 和 correction 不能同时提供。

  • out (None) – 未被 JAX 使用。

返回:

沿给定轴的方差数组。

返回类型:

Array

参见

示例

默认情况下,jnp.var 计算所有轴上的方差。

>>> x = jnp.array([[1, 3, 4, 2],
...                [5, 2, 6, 3],
...                [8, 4, 2, 9]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.var(x)
Array(5.74, dtype=float32)

如果 axis=1,则沿轴 1 计算方差。

>>> jnp.var(x, axis=1)
Array([1.25  , 2.5   , 8.1875], dtype=float32)

要保留输入的维度,可以设置 keepdims=True

>>> jnp.var(x, axis=1, keepdims=True)
Array([[1.25  ],
       [2.5   ],
       [8.1875]], dtype=float32)

如果 ddof=1

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, ddof=1))
[[ 1.67]
 [ 3.33]
 [10.92]]

要包含数组中的特定元素来计算方差,可以使用 where

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 1, 0],
...                    [1, 1, 1, 0]], dtype=bool)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.var(x, axis=1, keepdims=True, where=where))
[[2.25]
 [4.  ]
 [6.22]]