jax.experimental.pallas.GridSpec

jax.experimental.pallas.GridSpec#

class jax.experimental.pallas.GridSpec(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec)[源代码][源代码]#

jax.experimental.pallas.pallas_call() 编码网格参数。

请参阅 jax.experimental.pallas.pallas_call() 的文档,以及 网格和BlockSpecs 以获取参数的更详细描述。

参数:
  • grid (TupleGrid)

  • in_specs (BlockSpecTree)

  • out_specs (BlockSpecTree)

__init__(grid=(), in_specs=NoBlockSpec, out_specs=NoBlockSpec)[源代码][源代码]#
参数:
  • grid (Grid)

  • in_specs (BlockSpecTree)

  • out_specs (BlockSpecTree)

方法

__init__([grid, in_specs, out_specs])

属性

grid

grid_names

in_specs

out_specs