jax.lax.approx_max_k

jax.lax.approx_max_k#

jax.lax.approx_max_k(operand, k, reduction_dimension=-1, recall_target=0.95, reduction_input_size_override=-1, aggregate_to_topk=True)[源代码][源代码]#

以近似的方式返回 operand 中的最大 k 个值及其索引。

算法细节请参见 https://arxiv.org/abs/2206.14286

参数:
  • operand (Array) – 用于搜索最大k的数组。必须是浮点数类型。

  • k (int) – 指定最大k值的数量。

  • reduction_dimension (int) – 要搜索的整数维度。默认值:-1。

  • recall_target (float) – 近似的目标值。

  • reduction_input_size_override (int) – 当设置为正值时,它会覆盖由 operand[reduction_dim] 确定的用于评估召回率的大小。当给定的 operand 只是SPMD或分布式管道中整体计算的一个子集,且真实输入大小无法通过操作数形状推断时,此选项非常有用。

  • aggregate_to_topk (bool) – 当为真时,将按排序顺序聚合前k个近似结果。当为假时,返回未排序的近似结果。在这种情况下,近似结果的数量是实现定义的,并且大于或等于指定的 k

返回:

两个数组的元组。这些数组是输入 operandreduction_dimension 上的最大 k 值及其对应的索引。数组的维度与输入 operand 相同,除了 reduction_dimension:当 aggregate_to_topk 为真时,缩减维度为 k;否则,它大于等于 k,其中大小由实现定义。

返回类型:

tuple[Array, Array]

我们鼓励用户用 jit 包装 approx_max_k。请参见以下最大内积搜索 (MIPS) 的示例:

>>> import functools
>>> import jax
>>> import numpy as np
>>> @functools.partial(jax.jit, static_argnames=["k", "recall_target"])
... def mips(qy, db, k=10, recall_target=0.95):
...   dists = jax.lax.dot(qy, db.transpose())
...   # returns (f32[qy_size, k], i32[qy_size, k])
...   return jax.lax.approx_max_k(dists, k=k, recall_target=recall_target)
>>>
>>> qy = jax.numpy.array(np.random.rand(50, 64))
>>> db = jax.numpy.array(np.random.rand(1024, 64))
>>> dot_products, neighbors = mips(qy, db, k=10)