jax.numpy.choose

目录

jax.numpy.choose#

jax.numpy.choose(a, choices, out=None, mode='raise')[源代码][源代码]#

从索引数组和选择数组列表中构造一个数组。

LAX-backend 对 numpy.choose() 的实现。

原始文档字符串如下。

首先,如果有困惑或不确定,一定要查看示例 - 在完全通用的情况下,这个函数比从下面的代码描述中看起来要复杂一些(下面 ndi = numpy.lib.index_tricks):

np.choose(a,c) == np.array([c[a[I]][I] for I in ndi.ndindex(a.shape)]).

但这忽略了一些微妙之处。以下是一个全面的总结:

给定一个整数数组“index”(a)和一系列 n 个数组(choices),a 和每个选择数组首先根据需要广播为具有公共形状的数组;称这些为 BaBchoices[i], i = 0,…,n-1,我们必然有 Ba.shape == Bchoices[i].shape 对于每个 i。然后,创建一个形状为 Ba.shape 的新数组,如下所示:

  • 如果 mode='raise' (默认),那么首先,a 的每个元素(因此 Ba 的每个元素)必须在范围 [0, n-1] 内;现在,假设 i (在该范围内)是 Ba 中位置 (j0, j1, ..., jm) 处的值 - 那么新数组中相同位置的值是 Bchoices[i] 中相同位置的值;

  • 如果 mode='wrap'a 中的值(因此 Ba 中的值)可以是任何(有符号)整数;使用模运算将范围 [0, n-1] 之外的整数映射回该范围;然后按照上述方式构造新数组;

  • 如果 mode='clip'a 中的值(因此 Ba 中的值)可以是任何(有符号)整数;负整数映射为 0;大于 n-1 的值映射为 n-1;然后按照上述方式构建新数组。

参数:
  • a (int array) – 这个数组必须包含 [0, n-1] 范围内的整数,其中 n 是选项的数量,除非 mode=wrapmode=clip,在这种情况下,任何整数都是允许的。

  • choices (sequence of arrays) – 选择数组。a 和所有选择必须能够广播到相同的形状。如果 choices 本身是一个数组(不推荐),那么它的最外层维度(即对应于 choices.shape[0] 的维度)被视为定义“序列”。

  • mode ({'raise' (default), 'wrap', 'clip'}, optional) – 指定如何处理超出 [0, n-1] 范围的索引:* ‘raise’ : 引发异常 * ‘wrap’ : 值变为值 mod n * ‘clip’ : 值 < 0 映射到 0,值 > n-1 映射到 n-1

  • out (None)

返回:

merged_array – 合并后的结果。

返回类型:

array