jax.numpy.median#
- jax.numpy.median(a, axis=None, out=None, overwrite_input=False, keepdims=False)[源代码][源代码]#
返回沿给定轴的数组元素的中位数。
JAX 实现的
numpy.median()
。- 参数:
- 返回:
沿着给定轴的中位数数组。
- 返回类型:
参见
jax.numpy.mean()
: 计算数组元素在给定轴上的平均值。jax.numpy.max()
: 计算数组元素在给定轴上的最大值。jax.numpy.min()
: 计算数组元素在给定轴上的最小值。
示例
默认情况下,中位数是针对展平的数组计算的。
>>> x = jnp.array([[2, 4, 7, 1], ... [3, 5, 9, 2], ... [6, 1, 8, 3]]) >>> jnp.median(x) Array(3.5, dtype=float32)
如果
axis=1
,则沿轴 1 计算中位数。>>> jnp.median(x, axis=1) Array([3. , 4. , 4.5], dtype=float32)
如果
keepdims=True
,输出的ndim
等于输入的ndim
。>>> jnp.median(x, axis=1, keepdims=True) Array([[3. ], [4. ], [4.5]], dtype=float32)