jax.scipy.stats.mode

目录

jax.scipy.stats.mode#

jax.scipy.stats.mode(a, axis=0, nan_policy='propagate', keepdims=False)[源代码][源代码]#

计算数组沿轴的模式(最常见的值)。

JAX 实现的 scipy.stats.mode()

参数:
  • a (ArrayLike) – 类数组

  • axis (int | None) – int, 默认=0。计算众数的轴。

  • nan_policy (str) – str. JAX 仅支持 "propagate"

  • keepdims (bool) – bool, 默认=False。如果为真,缩减的轴将保留在结果中,尺寸为1。

返回:

一个数组的元组,(mode, count)mode 是众数值的数组,count 是每个值在输入数组中出现的次数。

返回类型:

ModeResult

示例

>>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
>>> mode, count = jax.scipy.stats.mode(x)
>>> mode, count
(Array(4, dtype=int32), Array(3, dtype=int32))

对于多维数组,jax.scipy.stats.mode 计算沿 axis=0mode 及其对应的 count

>>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
...                 [3, 1, 3, 2, 1, 3],
...                 [1, 2, 2, 3, 1, 2]])
>>> mode, count = jax.scipy.stats.mode(x1)
>>> mode, count
(Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

如果 axis=1modecount 将沿着 axis 1 计算。

>>> mode, count = jax.scipy.stats.mode(x1, axis=1)
>>> mode, count
(Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

默认情况下,jax.scipy.stats.mode 会减少结果的维度。要使结果的维度与输入数组相同,必须将参数 keepdims 设置为 True

>>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
>>> mode, count
(Array([[1],
       [3],
       [2]], dtype=int32), Array([[3],
       [3],
       [3]], dtype=int32))