基于 JAX 构建#
学习高级 JAX 使用的一个好方法是看看其他库是如何使用 JAX 的,包括它们如何将库集成到它们的 API 中,它在数学上增加了什么功能,以及它在其他库中如何用于计算加速。
以下是一些示例,展示了如何在多个领域和软件包中使用JAX的功能来定义加速计算。
梯度计算#
简单的梯度计算是JAX的一个关键特性。在 JaxOpt库 中,值和梯度直接被用于 其源代码 中的多种优化算法。
同样地,上述提到的相同 Dynamax Optax 配对是梯度使能估计方法的一个例子,这些方法在历史上是具有挑战性的 使用 Optax 的最大似然期望。
在多个设备上的单核计算加速#
在JAX中定义的模型可以通过编译来实现通过JIT编译的单次计算加速。相同的编译代码随后可以发送到CPU设备、GPU或TPU设备以获得额外的加速,通常不需要额外的更改。这使得从开发到生产的流程变得顺畅。在Dynamax中,线性状态空间模型求解器的计算密集部分已经被jitted。一个更复杂的例子来自PyTensor,它动态编译一个JAX函数,然后jit构建的函数。
使用并行化的单机和多机加速#
JAX 的另一个好处是使用 pmap
和 vmap
函数调用或装饰器来并行化计算的简单性。在 Dynamax 状态空间模型中,通过 VMAP 装饰器 并行化,这种用例的一个实际例子是多目标跟踪。
将 JAX 代码整合到您或您用户的工作流程中#
JAX 非常灵活,可以用多种方式使用。JAX 可以用于独立模式,用户自己定义所有的计算。然而,还有其他模式,例如使用基于 jax 构建的库来提供特定功能。这些库可以定义特定类型的模型,如神经网络或状态空间模型等,或提供特定的功能,如优化。以下是每种模式的更具体示例。
直接使用#
Jax 可以直接导入并用于构建“从头开始”的模型,如本网站所示,例如在 JAX 教程 或 使用 JAX 的神经网络 中。如果你找不到特定挑战的预构建代码,或者你希望减少代码库中的依赖项数量,这可能是最佳选择。
可组合的特定领域库,使用 JAX 暴露#
另一种常见的方法是提供预构建功能的包,无论是模型定义还是某种类型的计算。这些包的组合可以混合和匹配,以实现从定义模型到估计其参数的完整端到端工作流程。
一个例子是 Flax,它简化了神经网络的构建。Flax 通常与 Optax 配对使用,其中 Flax 定义神经网络架构,而 Optax 提供优化和模型拟合能力。
另一个是 Dynamax,它允许轻松定义状态空间模型。使用 Dynamax,可以使用 Optax 进行最大似然估计 或使用 Blackjax 的 MCMC 进行全贝叶斯后验估计。