jax.默认设备

jax.默认设备#

jax.default_device = <jax._src.config.State object>#

jax_default_device 配置选项的上下文管理器。

配置 JAX 操作的默认设备。设置为 Device 对象(例如 jax.devices("cpu")[0])以使用该设备作为 JAX 操作和 jit 函数调用的默认设备(对多设备计算没有影响,例如 pmapped 函数调用)。设置为 None 以使用系统默认设备。有关设备放置的更多信息,请参阅 控制数据和计算在设备上的放置

参数:

new_val (Any)