jax.numpy.nanmedian#
- jax.numpy.nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False)[源代码][源代码]#
返回沿给定轴的数组元素的中位数,忽略 NaNs。
JAX 实现的
numpy.nanmedian()
。- 参数:
- 返回:
包含沿指定轴的中位数,忽略 NaNs 的数组。如果沿指定轴的所有元素都是 NaNs,则返回
nan
。- 返回类型:
参见
jax.numpy.nanmean()
: 计算给定轴上数组元素的平均值,忽略 NaNs。jax.numpy.nanmax()
: 计算数组元素在给定轴上的最大值,忽略 NaN。jax.numpy.nanmin()
: 计算给定轴上数组元素的最小值,忽略 NaNs。
示例
默认情况下,中位数是针对展平的数组计算的。
>>> nan = jnp.nan >>> x = jnp.array([[2, nan, 7, nan], ... [nan, 5, 9, 2], ... [6, 1, nan, 3]]) >>> jnp.nanmedian(x) Array(4., dtype=float32)
如果
axis=1
,则沿轴 1 计算中位数。>>> jnp.nanmedian(x, axis=1) Array([4.5, 5. , 3. ], dtype=float32)
如果
keepdims=True
,输出的ndim
等于输入的ndim
。>>> jnp.nanmedian(x, axis=1, keepdims=True) Array([[4.5], [5. ], [3. ]], dtype=float32)