jax.make_jaxpr

目录

jax.make_jaxpr#

jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[False] = False, abstracted_axes: Any | None = None) Callable[..., core.ClosedJaxpr][源代码][源代码]#
jax.make_jaxpr(fun: Callable, static_argnums: int | Iterable[int] = (), axis_env: Sequence[tuple[AxisName, int]] | None = None, return_shape: Literal[True] = False, abstracted_axes: Any | None = None) Callable[..., tuple[core.ClosedJaxpr, Any]]

创建一个函数,该函数在给定示例参数的情况下生成其 jaxpr。

参数:
  • fun – 要计算其 jaxpr 的函数。其位置参数和返回值应为数组、标量或标准 Python 容器(元组/列表/字典)。

  • static_argnums – 参见 jax.jit() 的文档字符串。

  • axis_env – 可选的,一系列的成对序列,其中第一个元素是轴名称,第二个元素是一个正整数,表示具有该名称的映射轴的大小。当降低涉及并行通信集体的函数时,此参数非常有用,它指定了由 jax.pmap() 应用设置的轴名称/大小环境。

  • return_shape – 可选的布尔值,默认为 False。如果为 True,则包装的函数返回一个对,其中第一个元素是 funClosedJaxpr 表示,第二个元素是一个与 fun 输出结构相同的 pytree,其中叶子是具有 shapedtype 属性的对象,表示相应输出叶子的类型。

返回:

fun 的包装版本,当应用于示例参数时,返回 fun 在这些参数上的 ClosedJaxpr 表示。如果参数 return_shapeTrue,则返回的函数改为返回一个对,其中第一个元素是 funClosedJaxpr 表示,第二个元素是一个 pytree,表示 fun 输出的结构、形状、数据类型和命名形状。

jaxpr 是 JAX 用于程序追踪的中间表示。jaxpr 语言基于带有 let 绑定的简单类型一阶 lambda 演算。make_jaxpr() 将一个函数适配为返回其 jaxpr,我们可以检查它以理解 JAX 在内部所做的事情。返回的 jaxprfun 抽象到 ShapedArray 级别的追踪。内部存在其他抽象级别。

我们在这里不详细描述 jaxpr 语言的语义,而是给出一些例子。

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a:f32[]. let
    b:f32[] = cos a
    c:f32[] = sin a
    _:f32[] = sin b
    d:f32[] = cos b
    e:f32[] = mul 1.0 d
    f:f32[] = neg e
    g:f32[] = mul f c
  in (g,) }