jax.numpy.from_dlpack#
- jax.numpy.from_dlpack(x, /, *, device=None, copy=None)[源代码][源代码]#
通过 DLPack 构建 JAX 数组。
JAX 实现
numpy.from_dlpack()
。- 参数:
x (Any) – 一个通过
__dlpack__
和__dlpack_device__
方法实现 DLPack 协议的对象,或是在 CPU 或 GPU 上的传统 DLPack 张量。device (xc.Device | Sharding | None) – 一个可选的
Device
或Sharding
,表示应将返回的数组放置的单个设备。如果指定,则结果将被提交到该设备。如果未指定,则结果数组将被解包到其来源的同一设备上。将device
设置为与external_array
来源不同的设备将需要复制,这意味着copy
必须设置为True
或None
。copy (bool | None) – 一个可选的布尔值,控制是否执行复制操作。如果
copy=True
,则始终执行复制,即使解包到同一设备上。如果copy=False
,则从不执行复制,并且在必要时会引发错误。当 ``copy=None``(默认)时,如果需要进行设备转移,则可能会执行复制。
- 返回:
输入缓冲区的 JAX 数组。
- 返回类型:
备注
虽然 JAX 数组总是不可变的,但 dlpack 缓冲区不能被标记为不可变,并且外部进程有可能对其进行就地修改。如果一个 JAX 数组是从一个未复制的 dlpack 缓冲区构造的,并且源缓冲区随后被就地修改,那么在操作相关的 JAX 数组时可能会导致未定义行为。
示例
通过 DLPack 在 NumPy 和 JAX 之间传递数据:
>>> import numpy as np >>> rng = np.random.default_rng(42) >>> x_numpy = rng.random(4, dtype='float32') >>> print(x_numpy) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_numpy, "__dlpack__") # NumPy supports the DLPack interface True
>>> import jax.numpy as jnp >>> x_jax = jnp.from_dlpack(x_numpy) >>> print(x_jax) [0.08925092 0.773956 0.6545715 0.43887842] >>> hasattr(x_jax, "__dlpack__") # JAX supports the DLPack interface True
>>> x_numpy_round_trip = np.from_dlpack(x_jax) >>> print(x_numpy_round_trip) [0.08925092 0.773956 0.6545715 0.43887842]