jax.lax.with_sharding_constraint#
- jax.lax.with_sharding_constraint(x, shardings)[源代码][源代码]#
在即时编译计算中限制数组分片的机制
这是GSPMD分区器的一个严格约束,而不是一个提示。关于如何使用此功能的示例,请参见 分布式数组和自动并行化。
- 参数:
x – jax.Arrays 的 PyTree,其分片将受到约束
shardings – 分片规范的PyTree。有效值与
jax.experimental.pjit()
的in_shardings
参数相同。
- 返回:
具有指定分片约束的 jax.Arrays 的 PyTree。
- 返回类型:
x_with_shardings