jax.experimental.multihost_utils.host_local_array_to_global_array

jax.experimental.multihost_utils.host_local_array_to_global_array#

jax.experimental.multihost_utils.host_local_array_to_global_array(local_inputs, global_mesh, pspecs)[源代码][源代码]#

将主机本地值转换为全局分片的 jax.Array。

此函数获取主机本地数据(这些数据在不同主机之间可能不同),并使用这些数据填充一个全局数组,其中每个主机上的每个设备根据 global_mesh/pspects 定义的分片获取数据的适当切片。

例如:

>>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
>>> pspecs = jax.sharding.PartitionSpec('x')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.   

生成的数组将具有形状 (4 * num_processes),并且将分布有以下值:(0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, … ),其中每个切片 np.arange(4) * host_id 将在相应主机的设备上进行分区。

同样地:

>>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
>>> pspecs = jax.sharding.PartitionSpec('host')
>>> host_id = jax.process_index()
>>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  

将创建相同的分布值(0, 1, 2, 3, 0, 2, 4, 6, …),然而每个切片 np.arange(4) * i 将在相应的宿主设备上 复制

另一方面,如果 pspecs = PartitionSpec(),这意味着在所有轴上进行复制,那么这段代码:

>>> pspecs = jax.sharding.PartitionSpec()
>>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)  

将具有形状 (4,) 并且值 (0, 1, 2, 3) 将被复制到所有主机和设备上。

当 pspec 指示数据复制时,具有不相同的 local_inputs 是一种未定义行为。

你可以使用此函数来转换到 jax.Array。使用 jax.Array 与 pjit 具有与使用 GDA 与 pjit 相同的语义,即所有传递给 pjit 的 jax.Array 输入都应该是全局形状的。

如果你当前正在将主机本地值传递给 pjit,你可以使用此函数将你的主机本地值转换为全局数组,然后将该数组传递给 pjit。

示例用法。

>>> from jax.experimental import multihost_utils 
>>>
>>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) 
>>>
>>> with mesh: 
>>>   global_out = pjitted_fun(global_inputs) 
>>>
>>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) 

请注意,此功能要求全局网格为连续网格,这意味着每个主机所属的设备应在此网格中形成一个子立方体。要将本地数据移动到非连续网格的全局数组中,请使用 jax.make_array_from_callback 或 jax.make_array_from_single_device_arrays 代替。

参数:
  • local_inputs (Any) – 主机本地值的Pytree。

  • global_mesh (jax.sharding.Mesh) – 一个 jax.sharding.Mesh 对象。网格必须是一个连续的网格。

  • mesh. (that is all hosts' devices must form a subcube in this)

  • pspecs (Any) – 一个由 jax.sharding.PartitionSpec 组成的 Pytree。

返回:

全局数组的pytree。