jax.device_put_sharded#
- jax.device_put_sharded(shards, devices)[源代码][源代码]#
将数组分片传输到指定设备并形成数组。
- 参数:
shards (Sequence[Any]) – 一系列数组、标量或(嵌套的)标准 Python 容器,表示要堆叠在一起形成输出的分片。
shards
的长度必须等于devices
的长度。devices (Sequence[xc.Device]) – 表示要将
shards
中的相应分片转移到的Device
实例序列。
此函数总是异步的,即立即返回。
- 返回:
一个数组或(嵌套的)Python容器,表示
shards
的元素堆叠在一起,每个分片由devices
中相应条目指定的物理设备内存支持。- 参数:
shards (Sequence[Any])
devices (Sequence[xc.Device])
示例
传递一个数组的列表给
shards
会得到一个包含输入堆叠版本的碎片化数组:>>> import jax >>> devices = jax.local_devices() >>> x = [jax.numpy.ones(5) for device in devices] >>> y = jax.device_put_sharded(x, devices) >>> np.allclose(y, jax.numpy.stack(x)) True
将嵌套容器对象列表传递给
shards
,其中叶节点为数组,相当于在每个叶节点堆叠分片。这要求列表中的所有条目具有相同的树结构:>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))] >>> y = jax.device_put_sharded(x, devices) >>> type(y) <class 'tuple'> >>> y0 = jax.device_put_sharded([a for a, b in x], devices) >>> y1 = jax.device_put_sharded([b for a, b in x], devices) >>> np.allclose(y[0], y0) True >>> np.allclose(y[1], y1) True
参见
device_put
device_put_replicated