jax.lax.with_sharding_constraint

jax.lax.with_sharding_constraint#

jax.lax.with_sharding_constraint(x, shardings)[源代码][源代码]#

在即时编译计算中限制数组分片的机制

这是GSPMD分区器的一个严格约束,而不是一个提示。关于如何使用此功能的示例,请参见 分布式数组和自动并行化

参数:
  • x – jax.Arrays 的 PyTree,其分片将受到约束

  • shardings – 分片规范的PyTree。有效值与 jax.experimental.pjit()in_shardings 参数相同。

返回:

具有指定分片约束的 jax.Arrays 的 PyTree。

返回类型:

x_with_shardings