jax.device_put_replicated#
- jax.device_put_replicated(x, devices)[源代码][源代码]#
将数组传输到每个指定的设备并形成数组。
- 参数:
x (Any) – 一个数组、标量或(嵌套的)标准 Python 容器,表示要复制以形成输出的数组。
devices (Sequence[xc.Device]) – 表示将
x
转移到的设备
实例序列。
此函数总是异步的,即立即返回。
- 返回:
一个数组或(嵌套的)Python容器,表示
x
的值沿着一个新的前导轴广播,该轴的大小为len(devices)
,其中沿着该新前导轴的每个切片都由devices
中相应条目指定的设备上的内存支持。- 参数:
x (Any)
devices (Sequence[xc.Device])
示例
传递一个数组:
>>> import jax >>> devices = jax.local_devices() >>> x = jax.numpy.array([1., 2., 3.]) >>> y = jax.device_put_replicated(x, devices) >>> np.allclose(y, jax.numpy.stack([x for _ in devices])) True
参见
device_put
device_put_sharded