jax.numpy.sort

目录

jax.numpy.sort#

jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)[源代码][源代码]#

返回数组的排序副本。

JAX 实现的 numpy.sort()

参数:
  • a (ArrayLike) – 待排序的数组

  • axis (int | None) – 沿其排序的整数轴。默认为 -1,即最后一个轴。如果为 None,则在排序前 a 会被展平。

  • stable (bool) – 指定是否应使用稳定排序的布尔值。默认=True。

  • descending (bool) – 指定是否按降序排序的布尔值。默认值=False。

  • kind (None) – 已弃用;请改用 stable=True 或 stable=False 指定排序算法。

  • order (None) – 不受 JAX 支持

返回:

形状为 a.shape 的排序数组(如果 axis 是整数)或形状为 (a.size,) 的排序数组(如果 axis 是 None)。

返回类型:

Array

示例

简单的一维排序

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)

沿数组的最后一个轴排序:

>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)

参见