jax.numpy.count_nonzero#
- jax.numpy.count_nonzero(a, axis=None, keepdims=False)[源代码][源代码]#
返回沿指定轴的非零元素的数量。
JAX 实现的
numpy.count_nonzero()
。- 参数:
a (ArrayLike) – 输入数组。
axis (Axis) – 可选,整数或整数序列,默认=None。沿此轴计算非零元素的数量。如果为None,则在展平数组中计数。
keepdims (bool) – bool, 默认=False。如果为真,缩减的轴将保留在结果中,尺寸为1。
- 返回:
输入数组中指定轴上非零元素的数量数组。
- 返回类型:
示例
默认情况下,
jnp.count_nonzero
会计算所有轴上的非零值。>>> x = jnp.array([[1, 0, 0, 0], ... [0, 0, 1, 0], ... [1, 1, 1, 0]]) >>> jnp.count_nonzero(x) Array(5, dtype=int32)
如果
axis=1
,则沿轴 1 进行计数。>>> jnp.count_nonzero(x, axis=1) Array([1, 1, 3], dtype=int32)
要保留输入的维度,可以设置
keepdims=True
。>>> jnp.count_nonzero(x, axis=1, keepdims=True) Array([[1], [1], [3]], dtype=int32)