jax.lax.axis_index#
- jax.lax.axis_index(axis_name)[源代码][源代码]#
返回沿映射轴
axis_name
的索引。- 参数:
axis_name – 用于命名映射轴的可哈希 Python 对象。
- 返回:
表示索引的整数。
例如,如果有8个可用的XLA设备:
>>> from functools import partial >>> @partial(jax.pmap, axis_name='i') ... def f(_): ... return lax.axis_index('i') ... >>> f(np.zeros(4)) Array([0, 1, 2, 3], dtype=int32) >>> f(np.zeros(8)) Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32) >>> @partial(jax.pmap, axis_name='i') ... @partial(jax.pmap, axis_name='j') ... def f(_): ... return lax.axis_index('i'), lax.axis_index('j') ... >>> x, y = f(np.zeros((4, 2))) >>> print(x) [[0 0] [1 1] [2 2] [3 3]] >>> print(y) [[0 1] [0 1] [0 1] [0 1]]