jax.numpy.choose#
- jax.numpy.choose(a, choices, out=None, mode='raise')[源代码][源代码]#
从索引数组和选择数组列表中构造一个数组。
LAX-backend 对
numpy.choose()
的实现。原始文档字符串如下。
首先,如果有困惑或不确定,一定要查看示例 - 在完全通用的情况下,这个函数比从下面的代码描述中看起来要复杂一些(下面 ndi = numpy.lib.index_tricks):
np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)])
.但这忽略了一些微妙之处。以下是一个全面的总结:
给定一个整数数组“index”(a)和一系列
n
个数组(choices),a 和每个选择数组首先根据需要广播为具有公共形状的数组;称这些为 Ba 和 Bchoices[i], i = 0,…,n-1,我们必然有Ba.shape == Bchoices[i].shape
对于每个i
。然后,创建一个形状为Ba.shape
的新数组,如下所示:如果
mode='raise'
(默认),那么首先,a
的每个元素(因此Ba
的每个元素)必须在范围[0, n-1]
内;现在,假设i
(在该范围内)是Ba
中位置(j0, j1, ..., jm)
处的值 - 那么新数组中相同位置的值是Bchoices[i]
中相同位置的值;如果
mode='wrap'
,a 中的值(因此 Ba 中的值)可以是任何(有符号)整数;使用模运算将范围 [0, n-1] 之外的整数映射回该范围;然后按照上述方式构造新数组;如果
mode='clip'
,a 中的值(因此Ba
中的值)可以是任何(有符号)整数;负整数映射为 0;大于n-1
的值映射为n-1
;然后按照上述方式构建新数组。
- 参数:
a (int array) – 这个数组必须包含
[0, n-1]
范围内的整数,其中n
是选项的数量,除非mode=wrap
或mode=clip
,在这种情况下,任何整数都是允许的。choices (sequence of arrays) – 选择数组。a 和所有选择必须能够广播到相同的形状。如果 choices 本身是一个数组(不推荐),那么它的最外层维度(即对应于
choices.shape[0]
的维度)被视为定义“序列”。mode ({'raise' (default), 'wrap', 'clip'}, optional) – 指定如何处理超出
[0, n-1]
范围的索引:* ‘raise’ : 引发异常 * ‘wrap’ : 值变为值 modn
* ‘clip’ : 值 < 0 映射到 0,值 > n-1 映射到 n-1out (None)
- 返回:
merged_array – 合并后的结果。
- 返回类型:
array