jax.numpy.unique_all#
- jax.numpy.unique_all(x, /, *, size=None, fill_value=None)[源代码][源代码]#
从 x 中返回唯一值,以及索引、逆索引和计数。
JAX 实现
numpy.unique_all()
;这相当于调用jax.numpy.unique()
并将 return_index、return_inverse、return_counts 和 equal_nan 设置为 True。由于
unique_all
的输出大小取决于数据,该函数通常不兼容jit()
和其他 JAX 变换。JAX 版本增加了可选的size
参数,必须在jnp.unique
用于此类上下文中时静态指定该参数。- 参数:
x (ArrayLike) – 将从中提取唯一值的 N 维数组。
size (int | None) – 如果指定,则仅返回前
size
个排序后的唯一元素。如果唯一元素的数量少于size
指示的数量,则返回值将用fill_value
填充。fill_value (ArrayLike | None) – 当指定了
size
且元素数量少于指定数量时,用fill_value
填充剩余的条目。默认为最小唯一值。
- 返回:
values
:形状为
(n_unique,)
的数组,包含x
中的唯一值。
indices
:形状为
(n_unique,)
的数组。包含x
中每个唯一值首次出现的索引。对于一维输入,x[indices]
等价于values
。
inverse_indices
:形状为
x.shape
的数组。包含values
中每个x
值的索引。对于一维输入,values[inverse_indices]
等同于x
。
counts
:形状为
(n_unique,)
的数组。包含x
中每个唯一值的出现次数。
- 返回类型:
一个元组
(values, indices, inverse_indices, counts)
,具有以下属性
参见
jax.numpy.unique()
: 用于计算唯一值的通用函数。jax.numpy.unique_values()
: 仅计算values
。jax.numpy.unique_counts()
: 仅计算values
和counts
。jax.numpy.unique_inverse()
: 仅计算values
和inverse
。
示例
这里我们计算一维数组中的唯一值:
>>> x = jnp.array([3, 4, 1, 3, 1]) >>> result = jnp.unique_all(x)
结果是一个具有四个命名属性的
NamedTuple
。values
属性包含数组中的唯一值:>>> result.values Array([1, 3, 4], dtype=int32)
indices
属性包含输入数组中唯一values
的索引:>>> result.indices Array([2, 0, 1], dtype=int32) >>> jnp.all(result.values == x[result.indices]) Array(True, dtype=bool)
inverse_indices
属性包含输入在values
中的索引:>>> result.inverse_indices Array([1, 2, 0, 1, 0], dtype=int32) >>> jnp.all(x == result.values[result.inverse_indices]) Array(True, dtype=bool)
counts
属性包含输入中每个唯一值的计数:>>> result.counts Array([2, 2, 1], dtype=int32)
关于
size
和fill_value
参数的示例,请参见jax.numpy.unique()
。