jax.numpy.ix_

目录

jax.numpy.ix_#

jax.numpy.ix_(*args)[源代码][源代码]#

从 N 个一维序列返回一个多维网格(开放网格)。

JAX 实现的 numpy.ix_()

参数:

*args (ArrayLike) – N 一维数组

返回:

由 Jax 数组组成的开放网格元组,每个数组具有 N 个维度。

返回类型:

tuple[Array, …]

示例

>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
      [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)