jax.debug.检查_数组_分片

jax.debug.检查_数组_分片#

jax.debug.inspect_array_sharding(value, *, callback)[源代码][源代码]#

启用检查JIT函数内的数组分片。

这个函数在提供一个数组的 Pytree 时,会回调每个数组的切分,并在 pjit 计算中工作,从而可以检查所选择的中间切分。

callback 被调用的策略是 尽可能早 在分片信息可用时。这意味着如果 inspect_array_callback 在没有进行任何转换的情况下被调用,回调将立即发生,因为我们已经准备好数组及其分片。在 jax.jit 内部,回调将在降低时间发生,这意味着你可以使用 AOT API(jit(f).lower(...))触发回调。在 pjit 内部,回调发生在 编译时,因为分片由 XLA 决定。你可以通过使用 JAX 的 AOT API(pjit(f).lower(...).compile())触发回调。在所有情况下,回调将通过运行函数来触发,因为运行函数首先需要降低和编译。然而,一旦函数被编译并缓存,回调将不再发生。

此功能是实验性的,其行为未来可能会发生变化。

参数:
  • value – JAX 数组的 Pytree。

  • callback (Callable[[Sharding], None]) – 一个接受 Sharding 且不返回值的可调用对象。

在下面的例子中,我们打印出在 pjit 计算中一个中间值的分片情况:

>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.pjit import pjit
>>> from jax.sharding import Mesh, PartitionSpec
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x):
...   x = jnp.sin(x)
...   jax.debug.inspect_array_sharding(x, callback=print)
...   return jnp.square(x)
>>> f = pjit(f_, in_shardings=PartitionSpec('dev'),
...          out_shardings=PartitionSpec('dev'))
>>> with Mesh(jax.devices(), ('dev',)):
...   f.lower(x).compile()  
...
NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))