jax.numpy.reshape#
- jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated, copy=None)[源代码][源代码]#
返回数组的重新形状的副本。
JAX 实现
numpy.reshape(),基于jax.lax.reshape()实现。- 参数:
a (ArrayLike) – 重塑的输入数组
shape (DimSize | Shape | None) – 整数或整数序列,给出新的形状,必须与输入数组的大小匹配。如果任何单个维度给定大小
-1,它将被替换为一个值,使得输出具有正确的尺寸。order (str) –
'F'或'C',指定重塑应采用列优先(Fortran 风格,'F')还是行优先(C 风格,'C')顺序;默认是'C'。JAX 不支持order='A'。copy (bool | None) – JAX 不使用;JAX 总是返回一个副本,尽管在 JIT 下编译器可能会优化这些副本。
newshape (DimSize | Shape | DeprecatedArg) –
shape参数的已弃用别名。如果使用,将导致DeprecationWarning。
- 返回:
具有指定形状的输入数组的重新整形副本。
- 返回类型:
备注
与
numpy.reshape()不同,jax.numpy.reshape()将返回输入数组的副本而不是视图。然而,在 JIT 下,编译器会在可能的情况下优化掉这些副本,因此在实践中不会影响性能。参见
jax.Array.reshape(): 通过数组方法实现的等效功能。jax.numpy.ravel(): 将数组展平为1D形状。jax.numpy.squeeze(): 从数组的形状中移除一个或多个长度为1的轴。
示例
>>> x = jnp.array([[1, 2, 3], ... [4, 5, 6]]) >>> jnp.reshape(x, 6) Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (3, 2)) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
你可以使用
-1来自动计算一个与输入大小一致的形状:>>> jnp.reshape(x, -1) # -1 is inferred to be 6 Array([1, 2, 3, 4, 5, 6], dtype=int32) >>> jnp.reshape(x, (-1, 2)) # -1 is inferred to be 3 Array([[1, 2], [3, 4], [5, 6]], dtype=int32)
reshape 中轴的默认排序是 C 风格的行优先排序。要使用 Fortran 风格的列优先排序,请指定
order='F':>>> jnp.reshape(x, 6, order='F') Array([1, 4, 2, 5, 3, 6], dtype=int32) >>> jnp.reshape(x, (3, 2), order='F') Array([[1, 5], [4, 3], [2, 6]], dtype=int32)
为了方便,此功能也可以通过
jax.Array.reshape()方法来实现:>>> x.reshape(3, 2) Array([[1, 2], [3, 4], [5, 6]], dtype=int32)