jax.lax.collapse

目录

jax.lax.collapse#

jax.lax.collapse(operand, start_dimension, stop_dimension=None)[源代码][源代码]#

将数组的维度折叠成单一维度。

例如,如果 operand 是一个形状为 [2, 3, 4] 的数组,collapse(operand, 0, 2).shape == [6, 4]。被折叠维度的元素按主到次的顺序排列,即,以编号最低的维度作为变化最慢的维度。

参数:
  • operand (Array) – 一个输入数组。

  • start_dimension (int) – 要折叠的维度的起始位置(包括该位置)。

  • stop_dimension (int | None) – 要折叠的维度末尾(不包括)。传递 None 以折叠开始后的所有维度。

返回:

一个数组,其中维度 [start_dimension, stop_dimension) 已经被折叠(展平)成一个单一维度。

返回类型:

Array