jax.numpy.mask_indices

jax.numpy.mask_indices#

jax.numpy.mask_indices(*args, **kwargs)[源代码]#

返回一个掩码函数的索引,以访问 (n, n) 数组。

LAX-后端实现 numpy.mask_indices()

原始文档字符串如下。

假设 mask_func 是一个函数,对于一个大小为 (n, n) 的方形数组 a 和一个可能的偏移参数 k,当调用为 mask_func(a, k) 时,返回一个在某些位置为零的新数组(像 triutril 这样的函数正是这样做的)。然后,该函数返回非零值所在位置的索引。

参数:
  • n (int) – 返回的索引将有效访问形状为 (n, n) 的数组。

  • mask_func (callable) – 一个调用签名类似于 triutril 的函数。即,mask_func(x, k) 返回一个布尔数组,形状与 x 相同。k 是函数的可选参数。

  • k (scalar) – 一个可选参数,传递给 mask_func。像 triutril 这样的函数接受第二个参数,该参数被解释为偏移量。

返回:

indices – 对应于 mask_func(np.ones((n, n)), k) 为 True 的位置的 n 个索引数组。

返回类型:

tuple of arrays.