jax.Array.take

目录

jax.Array.take#

abstract Array.take(indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)[源代码]#

从数组中提取元素。

请参阅 jax.numpy.take() 获取完整文档。

参数:
  • self (Array)

  • indices (ArrayLike)

  • axis (int | None)

  • out (None)

  • mode (str | None)

  • unique_indices (bool)

  • indices_are_sorted (bool)

  • fill_value (StaticScalar | None)

返回类型:

Array