jax.device_get

目录

jax.device_get#

jax.device_get(x)[源代码][源代码]#

x 传输到主机。

如果 x 是一个 pytree,那么各个缓冲区将并行复制。

参数:

x (Any) – 一个数组、标量、Array 或(嵌套的)标准 Python 容器,表示要传输到主机的数组。

返回:

一个数组或(嵌套的)Python容器,表示 x 的值。

示例

传递一个数组:

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> jax.device_get(x)
array([1., 2., 3.], dtype=float32)

传递一个标量(无效果):

>>> jax.device_get(1)
1

参见

  • device_put

  • device_put_sharded

  • device_put_replicated