jax.numpy.mask_indices#
- jax.numpy.mask_indices(*args, **kwargs)[源代码]#
返回一个掩码函数的索引,以访问 (n, n) 数组。
LAX-后端实现
numpy.mask_indices()
。原始文档字符串如下。
假设 mask_func 是一个函数,对于一个大小为
(n, n)
的方形数组 a 和一个可能的偏移参数 k,当调用为mask_func(a, k)
时,返回一个在某些位置为零的新数组(像 triu 或 tril 这样的函数正是这样做的)。然后,该函数返回非零值所在位置的索引。