jax.lib 模块

jax.lib 模块#

jax.lib 包是一组内部工具和类型,用于连接 JAX 的 Python 前端和其 XLA 后端。

jax.lib.xla_bridge#

default_backend()

返回默认XLA后端的平台名称。

get_backend([platform])

get_compile_options(num_replicas, num_partitions)

返回根据标志值推导出的编译选项。

jax.lib.xla_client#

register_custom_call_target(name, fn[, ...])

注册一个自定义调用目标。