jax.device_put_sharded

jax.device_put_sharded#

jax.device_put_sharded(shards, devices)[源代码][源代码]#

将数组分片传输到指定设备并形成数组。

参数:
  • shards (Sequence[Any]) – 一系列数组、标量或(嵌套的)标准 Python 容器,表示要堆叠在一起形成输出的分片。shards 的长度必须等于 devices 的长度。

  • devices (Sequence[xc.Device]) – 表示要将 shards 中的相应分片转移到的 Device 实例序列。

此函数总是异步的,即立即返回。

返回:

一个数组或(嵌套的)Python容器,表示 shards 的元素堆叠在一起,每个分片由 devices 中相应条目指定的物理设备内存支持。

参数:
  • shards (Sequence[Any])

  • devices (Sequence[xc.Device])

示例

传递一个数组的列表给 shards 会得到一个包含输入堆叠版本的碎片化数组:

>>> import jax
>>> devices = jax.local_devices()
>>> x = [jax.numpy.ones(5) for device in devices]
>>> y = jax.device_put_sharded(x, devices)
>>> np.allclose(y, jax.numpy.stack(x))
True

将嵌套容器对象列表传递给 shards ,其中叶节点为数组,相当于在每个叶节点堆叠分片。这要求列表中的所有条目具有相同的树结构:

>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
>>> y = jax.device_put_sharded(x, devices)
>>> type(y)
<class 'tuple'>
>>> y0 = jax.device_put_sharded([a for a, b in x], devices)
>>> y1 = jax.device_put_sharded([b for a, b in x], devices)
>>> np.allclose(y[0], y0)
True
>>> np.allclose(y[1], y1)
True

参见

  • device_put

  • device_put_replicated