理解 Jaxprs#

更新时间:2020年5月3日(针对提交 f1a46fe)。

从概念上讲,可以将 JAX 变换视为首先将待变换的 Python 函数追踪特化为一种小巧且行为良好的中间形式,然后使用特定于变换的解释规则进行解释。JAX 能够在如此小的软件包中打包如此强大的功能的原因之一是,它从熟悉的灵活编程接口(Python 和 NumPy)开始,并使用实际的 Python 解释器来完成大部分繁重的工作,将计算的本质提炼为一种简单的静态类型表达式语言,具有有限的高阶功能。这种语言就是 jaxpr 语言。

并非所有 Python 程序都可以这样处理,但事实证明,许多科学计算和机器学习程序可以。

在我们继续之前,重要的是要指出,并非所有的 JAX 变换都会像上面描述的那样实际生成一个 jaxpr;有些,例如,微分或批处理,会在追踪过程中逐步应用变换。尽管如此,如果一个人想要理解 JAX 如何在内部工作,或者想要利用 JAX 追踪的结果,理解 jaxprs 是有用的。

jaxpr 实例表示一个具有一个或多个类型化参数(输入变量)和一个或多个类型化结果的函数。结果仅依赖于输入变量;没有从封闭作用域捕获的自由变量。输入和输出具有类型,在 JAX 中表示为抽象值。代码中有两种相关的表示形式,即 jax.core.Jaxprjax.core.ClosedJaxprjax.core.ClosedJaxpr 表示部分应用的 jax.core.Jaxpr,并且是当你使用 jax.make_jaxpr() 检查 jaxprs 时获得的内容。它具有以下字段:

  • jaxpr 是一个 jax.core.Jaxpr,表示函数的实际计算内容(如下所述)。

  • consts 是一个常量列表。

ClosedJaxpr 中最有趣的部分是实际的执行内容,它表示为一个 jax.core.Jaxpr,使用以下语法打印:

Jaxpr ::= { lambda Var* ; Var+. let
              Eqn*
            in  [Expr+] }
哪里:
  • jaxpr 的参数显示为两个由 ; 分隔的变量列表。第一组变量是为被提升出来的常量引入的占位符。这些被称为 constvars,在 jax.core.ClosedJaxpr 中,consts 字段保存相应的值。第二组变量,称为 invars,对应于被跟踪的 Python 函数的输入。

  • Eqn* 是一个方程列表,定义了引用中间表达式的中间变量。每个方程定义一个或多个变量作为对某些原子表达式应用原语的结果。每个方程仅使用输入变量和由先前方程定义的中间变量。

  • Expr+: 是 jaxpr 的输出原子表达式(字面量或变量)列表。

方程式如下所示:

Eqn  ::= Var+ = Primitive [ Param* ] Expr+
哪里:
  • Var+ 是一个或多个中间变量,定义为原语调用的输出(某些原语可以返回多个值)。

  • Expr+ 是一个或多个原子表达式,每个表达式可以是变量或字面常量。一个特殊的变量 unitvar 或字面 unit,打印为 *,表示在计算的其余部分中不需要的值,并且已被省略。也就是说,单位只是占位符。

  • Param* 是原语的零个或多个命名参数,打印在方括号中。每个参数显示为 Name = Value

大多数 jaxpr 原语是一阶的(它们只接受一个或多个 Expr 作为参数):

Primitive := add | sub | sin | mul | ...

jaxpr 原语在 jax.lax 模块中进行了文档化。

例如,下面是为函数 func1 生成的 jaxpr

>>> from jax import make_jaxpr
>>> import jax.numpy as jnp
>>> def func1(first, second):
...    temp = first + jnp.sin(second) * 3.
...    return jnp.sum(temp)
...
>>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

这里没有常量变量,ab 是输入变量,它们分别对应于 firstsecond 函数参数。标量字面量 3.0 保持内联。reduce_sum 原语有一个命名参数 axes,除了操作数 e

请注意,尽管调用 JAX 的程序执行会构建一个 jaxpr,但 Python 级别的控制流和 Python 级别的函数会正常执行。这意味着,仅仅因为一个 Python 程序包含函数和控制流,生成的 jaxpr 并不一定需要包含控制流或高阶特性。

例如,当追踪函数 func3 时,JAX 将会内联对 inner 的调用以及条件 if second.shape[0] > 4,并且将生成与之前相同的 jaxpr。

>>> def func2(inner, first, second):
...   temp = first + inner(second) * 3.
...   return jnp.sum(temp)
...
>>> def inner(second):
...   if second.shape[0] > 4:
...     return jnp.sin(second)
...   else:
...     assert False
...
>>> def func3(first, second):
...   return func2(inner, first, second)
...
>>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

处理 PyTrees#

在 jaxpr 中没有元组类型;相反,原语接受多个输入并产生多个输出。当处理具有结构化输入或输出的函数时,JAX 会将这些输入或输出展平,在 jaxpr 中它们将显示为输入和输出的列表。更多详情,请参阅 PyTrees 的文档 (Pytrees)。

例如,以下代码生成了一个与之前看到的相同的jaxpr(有两个输入变量,每个输入元组的元素对应一个)

>>> def func4(arg):  # Arg is a pair
...   temp = arg[0] + jnp.sin(arg[1]) * 3.
...   return jnp.sum(temp)
...
>>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
{ lambda ; a:f32[8] b:f32[8]. let
    c:f32[8] = sin b
    d:f32[8] = mul c 3.0
    e:f32[8] = add a d
    f:f32[] = reduce_sum[axes=(0,)] e
  in (f,) }

常量变量#

jaxprs 中的一些值是常量,因为它们的值不依赖于 jaxpr 的参数。当这些值是标量时,它们直接在 jaxpr 方程中表示;非标量数组常量则被提升到顶层 jaxpr 中,它们对应于常量变量(“constvars”)。这些 constvars 与其他 jaxpr 参数(“invars”)的区别仅在于簿记惯例。

高阶原语#

jaxpr 包含几个高阶原语。它们更复杂,因为它们包含子 jaxpr。

条件语句#

JAX 会跟踪普通的 Python 条件语句。为了捕获用于动态执行的条件表达式,必须使用 jax.lax.switch()jax.lax.cond() 构造函数,它们的签名如下:

lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B

lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B

这两个都会在内部绑定一个称为 cond 的基本操作。jaxprs 中的 cond 基本操作反映了 lax.switch() 的更一般签名:它接受一个整数,表示要执行的分支的索引(被限制在有效的索引范围内)。

例如:

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
    ] d b
  in (e,) }

cond 原语的 branches 参数对应于分支功能。在这个例子中,这些功能每个都接受一个输入变量,对应于 x

上述 cond 原语的实例接受两个操作数。第一个 (d) 是分支索引,然后 b 是操作数 (arg),将被传递给 branches 中由分支索引选择的 jaxpr。

另一个例子,使用 lax.cond():

>>> from jax import lax
>>>
>>> def func7(arg):
...   return lax.cond(arg >= 0.,
...                   lambda xtrue: xtrue + 3.,
...                   lambda xfalse: xfalse - 3.,
...                   arg)
...
>>> print(make_jaxpr(func7)(5.))
{ lambda ; a:f32[]. let
    b:bool[] = ge a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
        { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
      )
    ] c a
  in (d,) }

在这种情况下,布尔谓词被转换为整数索引(0或1),而 branches 是对应于假分支和真分支功能的jaxprs,按此顺序。同样,每个功能都接受一个输入变量,分别对应于 xfalsextrue

以下示例展示了一个更复杂的情况,当分支功能的输入是一个元组时,false 分支功能包含一个常量 jnp.ones(1),该常量被提升为一个 constvar

>>> def func8(arg1, arg2):  # arg2 is a pair
...   return lax.cond(arg1 >= 0.,
...                   lambda xtrue: xtrue[0],
...                   lambda xfalse: jnp.array([1]) + xfalse[1],
...                   arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
    e:bool[] = ge b 0.0
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
    g:f32[1] = cond[
      branches=(
        { lambda ; h:i32[1] i:f32[1] j:f32[]. let
            k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
            l:f32[1] = add k j
          in (l,) }
        { lambda ; m_:i32[1] n:f32[1] o:f32[]. let  in (n,) }
      )
    ] f a c d
  in (g,) }

虽然#

就像条件语句一样,Python 循环在追踪过程中会被内联。如果你想捕获一个循环以进行动态执行,你必须使用几种特殊操作之一,:py:func:`jax.lax.while_loop`(一个原语)和 :py:func:`jax.lax.fori_loop`(一个生成 while_loop 原语的辅助函数):

lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C

在上面的签名中,“C” 代表循环 “carry” 值的类型。例如,这里是一个 fori 循环的例子

>>> import numpy as np
>>>
>>> def func10(arg, n):
...   ones = jnp.ones(arg.shape)  # A constant
...   return lax.fori_loop(0, n,
...                        lambda i, carry: carry + ones * 3. + arg,
...                        arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda ; a:f32[16] b:i32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[16] = add a c
    _:i32[] _:i32[] e:f32[16] = while[
      body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
          k:i32[] = add h 1
          l:f32[16] = mul f 3.0
          m:f32[16] = add j l
          n:f32[16] = add m g
        in (k, i, n) }
      body_nconsts=2
      cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
          r:bool[] = lt o p
        in (r,) }
      cond_nconsts=0
    ] c a 0 b d
  in (e,) }

while 原语需要 5 个参数:c a 0 b d,如下所示:

  • cond_jaxpr 的常量(因为 cond_nconsts 是 0)

  • body_jaxpr 的 2 个常量 (c, 和 a)

  • 初始进位的3个参数

扫描#

JAX 支持一种特殊形式的循环,用于遍历数组的元素(形状是静态已知的)。由于迭代次数是固定的,这种形式的循环很容易进行反向微分。这种循环是通过 jax.lax.scan() 函数构建的:

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

这是以 Haskell 类型签名 的形式编写的:C 是扫描进位的类型,A 是输入数组的元素类型,B 是输出数组的元素类型。

以以下函数 func11 为例

>>> def func11(arr, extra):
...   ones = jnp.ones(arr.shape)  #  A constant
...   def body(carry, aelems):
...     # carry: running dot-product of the two arrays
...     # aelems: a pair with corresponding elements from the two arrays
...     ae1, ae2 = aelems
...     return (carry + ae1 * ae2 + extra, carry)
...   return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda ; a:f32[16] b:f32[]. let
    c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
    d:f32[] e:f32[16] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
          j:f32[] = mul h i
          k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          l:f32[] = add k j
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=16
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }

linear 参数描述了每个输入变量是否保证在线性体中使用。一旦扫描通过线性化,更多的参数将会是线性的。

扫描原语需要4个参数:b 0.0 a c,其中:

  • one 是主体的自由变量

  • one 是进位的初始值

  • 接下来的两个是扫描操作所涉及的数组。

XLA_调用#

调用原语源自JIT编译,它封装了一个子jaxpr以及指定计算应运行的后端和设备的参数。例如

>>> from jax import jit
>>>
>>> def func12(arg):
...   @jit
...   def inner(x):
...     return x + arg * jnp.ones(1)  # Include a constant in the inner function
...   return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))  
{ lambda ; a:f32[]. let
    b:f32[] = sub a 2.0
    c:f32[1] = pjit[
      name=inner
      jaxpr={ lambda ; d:f32[] e:f32[]. let
          f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          h:f32[1] = mul g f
          i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
          j:f32[1] = add i h
        in (j,) }
    ] a b
    k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    l:f32[1] = add k c
  in (l,) }

XLA_pmap#

如果你使用 jax.pmap() 转换,将要映射的函数是通过 xla_pmap 原语捕获的。考虑这个例子

>>> from jax import pmap
>>>
>>> def func13(arr, extra):
...   def inner(x):
...     # use a free variable "extra" and a constant jnp.ones(1)
...     return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
...   return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda ; a:f32[1,3] b:f32[]. let
    c:f32[1,3] = xla_pmap[
      axis_name=rows
      axis_size=1
      backend=None
      call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
          f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
          g:f32[3] = add e f
          h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
          i:f32[3] = add g h
          j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
          k:f32[3] = div i j
        in (k,) }
      devices=None
      donated_invars=(False, False)
      global_axis_size=1
      in_axes=(None, 0)
      is_explicit_global_axis_size=False
      name=inner
      out_axes=(0,)
    ] b a
  in (c,) }

xla_pmap 原语指定了轴的名称(参数 axis_name)以及要映射的函数体作为 call_jaxpr 参数。该参数的值是一个包含2个输入变量的Jaxpr。

参数 in_axes 指定哪些输入变量应被映射,哪些应被广播。在我们的示例中,extra 的值被广播,而 arr 的值被映射。