jax.numpy.indices

目录

jax.numpy.indices#

jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: Literal[False] = False) Array[源代码][源代码]#
jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, *, sparse: Literal[True]) tuple[Array, ...]
jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) Array | tuple[Array, ...]

返回一个表示网格索引的数组。

LAX-backend 实现的 numpy.indices()

原始文档字符串如下。

计算一个数组,其中子数组包含索引值 0, 1, … 仅沿相应的轴变化。

参数:
  • dimensions (sequence of ints) – 网格的形状。

  • dtype (dtype, optional) – 结果的数据类型。

  • sparse (boolean, optional) – 返回网格的稀疏表示,而不是密集表示。默认是 False。

返回:

网格 – 如果 sparse 为 False: 返回一个网格索引数组, grid.shape = (len(dimensions),) + tuple(dimensions)。如果 sparse 为 True: 返回一个数组元组,其中 grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), dimensions[i] 位于第 i 个位置

返回类型:

one ndarray or tuple of ndarrays