jax.Array.take# abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[源代码]# 从数组中提取元素。 请参阅 jax.numpy.take() 获取完整文档。 参数: self (Array) indices (ArrayLike) axis (int | None) out (None) mode (str | None) unique_indices (bool) indices_are_sorted (bool) fill_value (StaticScalar | None) 返回类型: Array