jax.lax.index_take# jax.lax.index_take(src, idxs, axes)[源代码][源代码]# 参数: src (Array) idxs (Array) axes (Sequence[int]) 返回类型: Array