jax.lax.expand_dims# jax.lax.expand_dims(array, dimensions)[源代码][源代码]# 在数组中插入任意数量的尺寸为1的维度。 参数: array (ArrayLike) dimensions (Sequence[int]) 返回类型: Array