基于 JAX 构建#

学习高级 JAX 使用的一个好方法是看看其他库是如何使用 JAX 的,包括它们如何将库集成到它们的 API 中,它在数学上增加了什么功能,以及它在其他库中如何用于计算加速。

以下是一些示例,展示了如何在多个领域和软件包中使用JAX的功能来定义加速计算。

梯度计算#

简单的梯度计算是JAX的一个关键特性。在 JaxOpt库 中,值和梯度直接被用于 其源代码 中的多种优化算法。

同样地,上述提到的相同 Dynamax Optax 配对是梯度使能估计方法的一个例子,这些方法在历史上是具有挑战性的 使用 Optax 的最大似然期望

在多个设备上的单核计算加速#

在JAX中定义的模型可以通过编译来实现通过JIT编译的单次计算加速。相同的编译代码随后可以发送到CPU设备、GPU或TPU设备以获得额外的加速,通常不需要额外的更改。这使得从开发到生产的流程变得顺畅。在Dynamax中,线性状态空间模型求解器的计算密集部分已经被jitted。一个更复杂的例子来自PyTensor,它动态编译一个JAX函数,然后jit构建的函数

使用并行化的单机和多机加速#

JAX 的另一个好处是使用 pmapvmap 函数调用或装饰器来并行化计算的简单性。在 Dynamax 状态空间模型中,通过 VMAP 装饰器 并行化,这种用例的一个实际例子是多目标跟踪。

将 JAX 代码整合到您或您用户的工作流程中#

JAX 非常灵活,可以用多种方式使用。JAX 可以用于独立模式,用户自己定义所有的计算。然而,还有其他模式,例如使用基于 jax 构建的库来提供特定功能。这些库可以定义特定类型的模型,如神经网络或状态空间模型等,或提供特定的功能,如优化。以下是每种模式的更具体示例。

直接使用#

Jax 可以直接导入并用于构建“从头开始”的模型,如本网站所示,例如在 JAX 教程使用 JAX 的神经网络 中。如果你找不到特定挑战的预构建代码,或者你希望减少代码库中的依赖项数量,这可能是最佳选择。

可组合的特定领域库,使用 JAX 暴露#

另一种常见的方法是提供预构建功能的包,无论是模型定义还是某种类型的计算。这些包的组合可以混合和匹配,以实现从定义模型到估计其参数的完整端到端工作流程。

一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对使用,其中 Flax 定义神经网络架构,而 Optax 提供优化和模型拟合能力。

另一个是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,可以使用 Optax 进行最大似然估计 或使用 Blackjax 的 MCMC 进行全贝叶斯后验估计。

JAX 对用户完全隐藏#

其他库选择在其模型特定的API中完全封装JAX。例如,PyMC和Pytensor,用户可能永远不会“直接看到”JAX,而是通过PyMC特定的API封装JAX函数