jax.lax.expand_dims

目录

jax.lax.expand_dims#

jax.lax.expand_dims(array, dimensions)[源代码][源代码]#

在数组中插入任意数量的尺寸为1的维度。

参数:
  • array (ArrayLike)

  • dimensions (Sequence[int])

返回类型:

Array