jax.numpy.indices#
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: Literal[False] = False) Array [源代码][源代码]#
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, *, sparse: Literal[True]) tuple[Array, ...]
- jax.numpy.indices(dimensions: Sequence[int], dtype: DTypeLike | None = None, sparse: bool = False) Array | tuple[Array, ...]
返回一个表示网格索引的数组。
LAX-backend 实现的
numpy.indices()
。原始文档字符串如下。
计算一个数组,其中子数组包含索引值 0, 1, … 仅沿相应的轴变化。
- 参数:
dimensions (sequence of ints) – 网格的形状。
dtype (dtype, optional) – 结果的数据类型。
sparse (boolean, optional) – 返回网格的稀疏表示,而不是密集表示。默认是 False。
- 返回:
网格 – 如果 sparse 为 False: 返回一个网格索引数组,
grid.shape = (len(dimensions),) + tuple(dimensions)
。如果 sparse 为 True: 返回一个数组元组,其中grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)
, dimensions[i] 位于第 i 个位置- 返回类型:
one ndarray or tuple of ndarrays