jax.devices

目录

jax.devices#

jax.devices(backend=None)[源代码][源代码]#

返回给定后端的所有设备列表。

每个设备由 Device 的子类表示(例如 CpuDeviceGpuDevice)。返回列表的长度等于 device_count(backend)。本地设备可以通过比较 Device.process_indexjax.process_index() 返回的值来识别。

如果 backendNone,则返回默认后端的所有设备。默认后端通常是 'gpu''tpu'``(如果可用),否则是 ``'cpu'

参数:

backend (str | xla_client.Client | None) – 这是一个实验性功能,API 可能会发生变化。可选,表示 xla 后端的字符串:'cpu''gpu''tpu'

返回:

设备子类列表。

返回类型:

list[xla_client.Device]