jax.ShapeDtypeStruct

jax.ShapeDtypeStruct#

class jax.ShapeDtypeStruct(shape, dtype, named_shape=None, sharding=None, weak_type=False)[源代码][源代码]#

用于存储数组的形状、数据类型和其他静态属性的容器。

ShapeDtypeStruct 通常与 jax.eval_shape() 一起使用。

参数:
  • shape – 一个表示数组形状的整数序列

  • dtype – 一个类似数据类型的对象

  • sharding – (可选) 一个 jax.Sharding 对象

__init__(shape, dtype, named_shape=None, sharding=None, weak_type=False)[源代码][源代码]#

方法

__init__(shape, dtype[, named_shape, ...])

属性

shape

dtype

sharding

weak_type

layout

named_shape

ndim

size