jax.device_get#
- jax.device_get(x)[源代码][源代码]#
将
x
传输到主机。如果
x
是一个 pytree,那么各个缓冲区将并行复制。- 参数:
x (Any) – 一个数组、标量、Array 或(嵌套的)标准 Python 容器,表示要传输到主机的数组。
- 返回:
一个数组或(嵌套的)Python容器,表示
x
的值。
示例
传递一个数组:
>>> import jax >>> x = jax.numpy.array([1., 2., 3.]) >>> jax.device_get(x) array([1., 2., 3.], dtype=float32)
传递一个标量(无效果):
>>> jax.device_get(1) 1
参见
device_put
device_put_sharded
device_put_replicated