jax.lax.平台依赖

jax.lax.平台依赖#

jax.lax.platform_dependent(*args, default=None, **per_platform)[源代码][源代码]#

移除平台特定的代码。

在JAX中,计算实际运行的平台是在非常晚的时候确定的,例如,基于数据所在的位置。当使用AOT降低或序列化时,计算可能会在不同的机器上编译和执行,甚至可能在降低时不存在的平台上执行。这意味着使用Python条件编写依赖于平台的代码是不安全的,例如,基于当前默认的JAX平台。相反,可以使用 platform_dependent

用法:

def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...
res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)

当在CPU上执行暂存的代码时,这等同于 cpu_code(*args),在TPU上等同于 tpu_code(*args),在其他平台上等同于 other_platforms_code(*args)。与Python条件语句不同,所有备选方案都会被追踪并暂存到Jaxpr中。这与 switch() 类似,并且是基于它实现的,因此继承了在变换下的行为。

switch() 不同,执行的选择是在更早的时候做出的:在大多数情况下,在降低平台已知时进行;在罕见的多平台降低和序列化情况下,StableHLO 代码将包含一个关于实际平台的条件。这个条件在编译平台已知之前及时解决。这意味着编译器实际上从未看到一个条件。

参数:
  • *args (Any) – 传递给每个分支的 JAX 数组。可以是 PyTrees。

  • **per_platform (Callable[..., _T]) – 用于不同平台的分支。这些分支是使用 *args 调用的 JAX 可调用对象。关键词是平台名称,例如 ‘cpu’、’tpu’、’cuda’、’rocm’。

  • default (Callable[..., _T] | None) – 可选的默认分支,用于未在 per_platform 中提及的平台。如果没有 default,则在为 per_platform 中未提及的平台降低代码时会出现错误。

返回:

per_platform[execution_platform](*args)