jax.device_put_replicated

jax.device_put_replicated#

jax.device_put_replicated(x, devices)[源代码][源代码]#

将数组传输到每个指定的设备并形成数组。

参数:
  • x (Any) – 一个数组、标量或(嵌套的)标准 Python 容器,表示要复制以形成输出的数组。

  • devices (Sequence[xc.Device]) – 表示将 x 转移到的 设备 实例序列。

此函数总是异步的,即立即返回。

返回:

一个数组或(嵌套的)Python容器,表示 x 的值沿着一个新的前导轴广播,该轴的大小为 len(devices),其中沿着该新前导轴的每个切片都由 devices 中相应条目指定的设备上的内存支持。

参数:
  • x (Any)

  • devices (Sequence[xc.Device])

示例

传递一个数组:

>>> import jax
>>> devices = jax.local_devices()
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put_replicated(x, devices)
>>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
True

参见

  • device_put

  • device_put_sharded