自学:从零开始了解 JAX 核心#

https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb

是否曾想过了解 JAX 是如何工作的,但实现看起来难以理解? 那么,运气来了!通过阅读这个教程,你将学习到 JAX 核心系统中的每一个重要概念。你甚至会了解到我们独特的术语!

这是一份正在进行中的草稿。 仍有一些重要的内容缺失,将在第 5 章和第 6 章(还有更多?)中出现。这里也有一些简化的内容,我们尚未应用于主系统,但我们会做到。

第1部分:作为解释器的变换:标准求值,jvpvmap#

我们想要转化看起来像这样的函数:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

sin 等函数及其背后的算术运算(muladdneg)视为原始操作,这意味着它们是原子处理单元,而不是组合。

“变换”意味着“以不同的方式解释”。我们希望重写原始操作的应用,而不是标准的解释,在这里我们将原始操作应用于数值输入以产生数值输出,而是允许不同的值在我们的程序中流动。例如,我们可能希望用每个原始操作的 JVP 规则 替代普通的应用,并让原始-切线对在我们的程序中流动。此外,我们希望能够组合多重变换,从而形成解释器的堆栈。

JAX核心机制#

我们可以实现解释器的堆栈,甚至可以在执行要转换的Python函数时动态地让它们全部释放。首先,让我们定义这些原语,以便我们可以截获它们的应用:

from typing import NamedTuple

class Primitive(NamedTuple):
  name: str

add_p = Primitive('add')
mul_p = Primitive('mul')
neg_p = Primitive("neg")
sin_p = Primitive("sin")
cos_p = Primitive("cos")
reduce_sum_p = Primitive("reduce_sum")
greater_p = Primitive("greater")
less_p = Primitive("less")
transpose_p = Primitive("transpose")
broadcast_p = Primitive("broadcast")

def add(x, y): return bind1(add_p, x, y)
def mul(x, y): return bind1(mul_p, x, y)
def neg(x): return bind1(neg_p, x)
def sin(x): return bind1(sin_p, x)
def cos(x): return bind1(cos_p, x)
def greater(x, y): return bind1(greater_p, x, y)
def less(x, y): return bind1(less_p, x, y)
def transpose(x, perm): return bind1(transpose_p, x, perm=perm)
def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)
def reduce_sum(x, axis=None):
  if axis is None:
    axis = tuple(range(np.ndim(x)))
  if type(axis) is int:
    axis = (axis,)
  return bind1(reduce_sum_p, x, axis=axis)

def bind1(prim, *args, **params):
  out, = bind(prim, *args, **params)
  return out

我们将稍后设置数组数据类型和中缀运算符方法。

一个Primitive只是一个带有名称的对象,我们为其附加解释规则(每个变换一个)。bind函数是我们的拦截点:它将根据参数在追踪器中的封装方式和活动的解释器来确定应用哪个变换规则。

用户代码调用的函数,例如addsin,只是对bind调用的封装。这些封装使我们能够控制如何将参数传递给bind,特别是我们遵循一个方便的内部约定:当我们调用bind时,我们将表示数组数据的值作为位置参数传递,而将sum_paxis参数等元数据通过关键字传递。这个调用约定简化了一些核心逻辑(因为例如下面要定义的Tracer类的实例只能作为位置参数出现在bind中)。这些封装也可以提供文档字符串!

我们将活动解释器表示为一个栈。这个栈只是一个简单的list,每个元素是一个包含整数层级(对应于元素在栈中的高度)、解释器类型(我们称之为trace_type)和一个可选字段(供解释器需要的任何全局数据)容器。我们称每个元素为MainTrace,不过“解释器”也许更具描述性。

from collections.abc import Sequence
from contextlib import contextmanager
from typing import Any

class MainTrace(NamedTuple):
  level: int
  trace_type: type['Trace']
  global_data: Any | None

trace_stack: list[MainTrace] = []
dynamic_trace: MainTrace | None = None  # 用于第三部分

@contextmanager
def new_main(trace_type: type['Trace'], global_data=None):
  level = len(trace_stack)
  main = MainTrace(level, trace_type, global_data)
  trace_stack.append(main)

  try:
    yield main
  finally:
    trace_stack.pop()

在我们即将应用变换时,我们将使用 new_main 将另一个解释器推入栈中。然后,随着我们在函数中应用原语,我们可以认为 bind 首先是由栈顶(即最高级别)的跟踪器进行解释的。如果那个第一个解释器在其对原语的解释规则中绑定了其他原语,例如 sin_p 的 JVP 规则可能绑定 cos_pmul_p,那么那些 bind 调用将由下一层的解释器处理。

解释器栈的底部是什么?在底部,我们知道所有的变换解释器都已经完成,我们只想进行标准评估。因此,在底部我们会放置一个评估解释器。

让我们勾画出解释器的接口,这基于 TraceTracer 基类。Tracer 表示一个封装的值,可能携带一些由解释器使用的额外上下文数据。Trace 处理将值封装成 Tracers 并且还处理原语应用。

class Trace:
  main: MainTrace

  def __init__(self, main: MainTrace) -> None:
    self.main = main

  def pure(self, val): assert False  # 必须重写
  def lift(self, val): assert False  # 必须重写

  def process_primitive(self, primitive, tracers, params):
    assert False  # 必须重写

前两种方法是关于将值封装在 Tracer 中,Tracer 是在我们转换的 Python 程序中流动的对象。最后一种方法是我们将用来解释原始应用的回调。

Trace 本身不包含任何数据,除了对其对应的 MainTrace 实例的引用。实际上,在应用变换的过程中,可能会创建和丢弃多个 Trace 实例,而每次应用变换只会创建一个 MainTrace 实例。

至于 Tracer 本身,每个 Tracer 都携带一个抽象值(并将中缀运算符转发给它),其余部分由变换决定。(TracerAbstractValue 之间的关系是每个变换对应一个 Tracer,而每种基本类型(如数组)至少对应一个 AbstractValue。)

import numpy as np

class Tracer:
  _trace: Trace

  __array_priority__ = 1000

  @property
  def aval(self):
    assert False  # 必须重写

  def full_lower(self):
    return self  # 默认实现

  def __neg__(self): return self.aval._neg(self)
  def __add__(self, other): return self.aval._add(self, other)
  def __radd__(self, other): return self.aval._radd(self, other)
  def __mul__(self, other): return self.aval._mul(self, other)
  def __rmul__(self, other): return self.aval._rmul(self, other)
  def __gt__(self, other): return self.aval._gt(self, other)
  def __lt__(self, other): return self.aval._lt(self, other)
  def __bool__(self): return self.aval._bool(self)
  def __nonzero__(self): return self.aval._nonzero(self)

  def __getattr__(self, name):
    try:
      return getattr(self.aval, name)
    except AttributeError:
      raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

def swap(f): return lambda x, y: f(y, x)
class ShapedArray:
  array_abstraction_level = 1
  shape: tuple[int, ...]
  dtype: np.dtype

  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype

  @property
  def ndim(self):
    return len(self.shape)

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(swap(add))
  _mul = staticmethod(mul)
  _rmul = staticmethod(swap(mul))
  _gt = staticmethod(greater)
  _lt = staticmethod(less)

  @staticmethod
  def _bool(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  @staticmethod
  def _nonzero(tracer):
    raise Exception("ShapedArray can't be unambiguously converted to bool")

  def str_short(self):
    return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]'

  def __hash__(self):
    return hash((self.shape, self.dtype))

  def __eq__(self, other):
    return (type(self) is type(other) and
            self.shape == other.shape and self.dtype == other.dtype)

  def __repr__(self):
    return f"ShapedArray(shape={self.shape}, dtype={self.dtype})"

class ConcreteArray(ShapedArray):
  array_abstraction_level = 2
  val: np.ndarray

  def __init__(self, val):
    self.val = val
    self.shape = val.shape
    self.dtype = val.dtype

  @staticmethod
  def _bool(tracer):
    return bool(tracer.aval.val)

  @staticmethod
  def _nonzero(tracer):
    return bool(tracer.aval.val)

def get_aval(x):
  if isinstance(x, Tracer):
    return x.aval
  elif type(x) in jax_types:
    return ConcreteArray(np.asarray(x))
  else:
    raise TypeError(x)

jax_types = {bool, int, float,
             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}

请注意,我们实际上有两个 AbstractValue 用于数组,分别代表不同的抽象层次。ShapedArray 表示具有给定形状和数据类型的所有可能数组的集合。ConcreteArray 表示一个仅包含单个数组值的单例集合。

现在我们已经设置好了解释器栈、解释器的 Trace/Tracer API 和抽象值,我们可以回过头来实现 bind

def bind(prim, *args, **params):
  top_trace = find_top_trace(args)
  tracers = [full_raise(top_trace, arg) for arg in args]
  outs = top_trace.process_primitive(prim, tracers, params)
  return [full_lower(out) for out in outs]

主要的操作是我们调用 find_top_trace 来确定哪个解释器应该处理这个原始应用。然后我们调用该顶级追踪的 process_primitive,这样追踪就可以应用它的解释规则。对 full_raise 的调用只是为了确保输入被封装在顶级追踪的 Tracer 实例中,而对 full_lower 的调用是一个可选的优化,以便尽可能地从 Tracer 中解封装值。

import operator as op

def find_top_trace(xs) -> Trace:
  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
                 default=trace_stack[0], key=op.attrgetter('level'))
  if dynamic_trace and dynamic_trace.level > top_main.level:
    top_main = dynamic_trace
  return top_main.trace_type(top_main)

在不考虑第3部分的dynamic_trace步骤的情况下,find_top_trace 返回与输入中的 Tracer 关联的最高级别解释器,否则返回栈底的解释器(至少目前总是评估跟踪)。这与上述描述有所偏离,我们总是从栈顶的解释器开始,然后逐步向下应用栈中的每个解释器。相反,我们仅在原始绑定的输入参数被装箱在一个与该解释器对应的 Tracer 中时应用解释器。这种优化使我们可以跳过不相关的转换,但它假设转换主要遵循数据依赖关系(除了栈底的特殊解释器,它解释所有内容)。

另一种选择是让栈中的每个解释器解释每个操作。这是值得探索的!JAX 设计的很大部分是围绕数据依赖关系,因为这对自动微分来说是非常自然的,而 JAX 的起源也在于自动微分。但这可能有些过拟合。

def full_lower(val: Any):
  if isinstance(val, Tracer):
    return val.full_lower()
  else:
    return val

def full_raise(trace: Trace, val: Any) -> Tracer:
  if not isinstance(val, Tracer):
    assert type(val) in jax_types
    return trace.pure(val)
  level = trace.main.level
  if val._trace.main is trace.main:
    return val
  elif val._trace.main.level < level:
    return trace.lift(val)
  elif val._trace.main.level > level:
    raise Exception(f"Can't lift level {val._trace.main.level} to {level}.")
  else:  # val._trace.level 等于 level
    raise Exception(f"Different traces at same level: {val._trace}, {trace}.")

full_raise中的逻辑旨在将值包装到特定TraceTracer中,根据上下文调用Trace上的不同方法:对于非Tracer常量调用Trace.pure,而对于来自低级解释器的已经是Tracer的值调用Trace.lift。这两个方法可以共享相同的实现,但通过在核心逻辑中区分它们,我们可以为Trace子类提供更多信息。

JAX核心的内容就是这些!现在我们可以开始添加解释器了。

评估解释器#

我们将从最简单的解释器开始:评估解释器,它将位于解释器栈的底部。

class EvalTrace(Trace):
  pure = lift = lambda self, x: x  # 无需在Tracer中进行装箱

  def process_primitive(self, primitive, tracers, params):
    return impl_rules[primitive](*tracers, **params)

trace_stack.append(MainTrace(0, EvalTrace, None))  # 特殊栈底

# 注意:在 JAX 中,我们不是将实现规则附加到字典上,而是将其附加到 Primitive 实例上。
impl_rules = {}

impl_rules[add_p] = lambda x, y: [np.add(x, y)]
impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]
impl_rules[neg_p] = lambda x: [np.negative(x)]
impl_rules[sin_p] = lambda x: [np.sin(x)]
impl_rules[cos_p] = lambda x: [np.cos(x)]
impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]
impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]
impl_rules[less_p] = lambda x, y: [np.less(x, y)]
impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]

def broadcast_impl(x, *, shape, axes):
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)]
impl_rules[broadcast_p] = broadcast_impl

通过这个解释器,我们可以评估用户的函数:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(f(3.0))
2.7177599838802657

哇!就像绕着一个大圆圈走。但是这种间接的意义在于我们现在可以添加一些真实的转换。

正向模式自动微分与 jvp#

首先,一些辅助函数:

import builtins

def zeros_like(val):
  aval = get_aval(val)
  return np.zeros(aval.shape, aval.dtype)

def unzip2(pairs):
  lst1, lst2 = [], []
  for x1, x2 in pairs:
    lst1.append(x1)
    lst2.append(x2)
  return lst1, lst2

def map(f, *xs):
  return list(builtins.map(f, *xs))

def zip(*args):
  fst, *rest = args = map(list, args)
  n = len(fst)
  for arg in rest:
    assert len(arg) == n
  return list(builtins.zip(*args))

Tracer 用于前向模式自动微分,携带一个原始-切向量对。Trace 应用 JVP 规则。

class JVPTracer(Tracer):
  def __init__(self, trace, primal, tangent):
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    return get_aval(self.primal)

class JVPTrace(Trace):
  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))

  def process_primitive(self, primitive, tracers, params):
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp_rule = jvp_rules[primitive]
    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)
    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]

jvp_rules = {}

注意到 purelift 包将一个值提升到 JVPTracer 中,并且上下文最少,即零切线值。

让我们为原始类型添加一些 JVP 规则:

def add_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x + y], [x_dot + y_dot]
jvp_rules[add_p] = add_jvp

def mul_jvp(primals, tangents):
  (x, y), (x_dot, y_dot) = primals, tangents
  return [x * y], [x_dot * y + x * y_dot]
jvp_rules[mul_p] = mul_jvp

def sin_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [sin(x)], [cos(x) * x_dot]
jvp_rules[sin_p] = sin_jvp

def cos_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [cos(x)], [-sin(x) * x_dot]
jvp_rules[cos_p] = cos_jvp

def neg_jvp(primals, tangents):
  (x,), (x_dot,) = primals, tangents
  return [neg(x)], [neg(x_dot)]
jvp_rules[neg_p] = neg_jvp

def reduce_sum_jvp(primals, tangents, *, axis):
  (x,), (x_dot,) = primals, tangents
  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]
jvp_rules[reduce_sum_p] = reduce_sum_jvp

def greater_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = greater(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[greater_p] = greater_jvp

def less_jvp(primals, tangents):
  (x, y), _ = primals, tangents
  out_primal = less(x, y)
  return [out_primal], [zeros_like(out_primal)]
jvp_rules[less_p] = less_jvp

最后,我们添加一个转换 API 来启动追踪:

def jvp_v1(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent
  return primal_out, tangent_out

这样,我们就可以进行微分了!

x = 3.0
y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))
print(sin_deriv_at_3)
print(cos(3.0))
-0.9899924966004454
-0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp_v1(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
def deriv(f):
  return lambda x: jvp_v1(f, (x,), (1.,))[1]

print(deriv(sin)(3.))
print(deriv(deriv(sin))(3.))
print(deriv(deriv(deriv(sin)))(3.))
print(deriv(deriv(deriv(deriv(sin))))(3.))
-0.9899924966004454
-0.1411200080598672
0.9899924966004454
0.1411200080598672
def f(x):
  if x > 0.:  # Python 控制流程
    return 2. * x
  else:
    return x

print(deriv(f)(3.))
print(deriv(f)(-3.))
2.0
1.0

Pytrees 和扁平化用户函数的输入输出#

jvp_v1的一个限制是它假设用户函数接受数组作为位置参数并产生一个数组作为输出。如果它产生一个列表作为输出呢?或者接受嵌套容器作为输入呢?在栈的每一层处理所有可能的输入和输出容器将会很麻烦。相反,我们可以包装用户函数,使得包装后的版本接受数组作为输入并返回一个扁平化的数组列表作为输出。这个包装器只需要解开其输入,调用用户函数,并扁平化输出即可。

假设用户总是给我们提供接受数组作为输入并产生扁平化数组列表作为输出的函数,下面是我们想要书写的jvp

def jvp_flat(f, primals, tangents):
  with new_main(JVPTrace) as main:
    trace = JVPTrace(main)
    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
  return primals_out, tangents_out

为了支持输入和输出中具有任意容器的用户函数,下面是我们如何编写面向用户的 jvp 包装器:

def jvp(f, primals, tangents):
  primals_flat, in_tree = tree_flatten(primals)
  tangents_flat, in_tree2 = tree_flatten(tangents)
  if in_tree != in_tree2: raise TypeError
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)
  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)
  return primals_out, tangents_out

注意到我们必须将用户函数输出的树结构传回到flatten_fun的调用者。该信息在实际运行用户函数之前是不可用的,因此flatten_fun仅返回一个可变单元的引用,表示为惰性计算。这些副作用是安全的,因为我们始终只运行用户函数一次。(这种安全机制是linear_util.py中“线性”名称的原因,从线性类型的角度来看。)

剩下的就是编写tree_flattentree_unflattenflatten_fun

Hide code cell source
def flatten_fun(f, in_tree):
  store = Store()

  def flat_fun(*args_flat):
    pytree_args = tree_unflatten(in_tree, args_flat)
    out = f(*pytree_args)
    out_flat, out_tree = tree_flatten(out)
    store.set_value(out_tree)
    return out_flat

  return flat_fun, store

class Empty: pass
empty = Empty()

class Store:
  val = empty

  def set_value(self, val):
    assert self.val is empty
    self.val = val

  def __call__(self):
    return self.val
Hide code cell source
from collections.abc import Hashable, Iterable, Iterator
import itertools as it
from collections.abc import Callable

class NodeType(NamedTuple):
  name: str
  to_iterable: Callable
  from_iterable: Callable

def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable
                         ) -> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: dict[type, NodeType] = {}
register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))
register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))
register_pytree_node(dict,
                     lambda d: map(tuple, unzip2(sorted(d.items()))),
                     lambda keys, vals: dict(zip(keys, vals)))

class PyTreeDef(NamedTuple):
  node_type: NodeType
  node_metadata: Hashable
  child_treedefs: tuple['PyTreeDef', ...]

class Leaf: pass
leaf = Leaf()

def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
  children_iter, treedef = _tree_flatten(x)
  return list(children_iter), treedef

def _tree_flatten(x: Any) -> tuple[Iterable, PyTreeDef]:
  node_type = node_types.get(type(x))
  if node_type:
    node_metadata, children = node_type.to_iterable(x)
    children_flat, child_trees = unzip2(map(_tree_flatten, children))
    flattened = it.chain.from_iterable(children_flat)
    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))
  else:
    return [x], leaf

def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any:
  return _tree_unflatten(treedef, iter(xs))

def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:
  if treedef is leaf:
    return next(xs)
  else:
    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)
    return treedef.node_type.from_iterable(treedef.node_metadata, children)

通过这个处理 pytree 的 jvp 实现,我们现在可以处理任意的输入和输出容器。这在未来的变换中也会派上用场!

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return {'hi': z, 'there': [x, y]}

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
{'hi': np.float64(2.7177599838802657), 'there': [3.0, np.float64(0.2822400161197344)]}
{'hi': np.float64(2.979984993200891), 'there': [1.0, np.float64(-1.9799849932008908)]}

使用 vmap 的向量化批处理#

首先,几个辅助函数,一个用于从未映射的抽象值生成映射抽象值(通过移除一个轴),另一个用于移动批处理维度:

def mapped_aval(batch_dim, aval):
  shape = list(aval.shape)
  del shape[batch_dim]
  return ShapedArray(tuple(shape), aval.dtype)

def move_batch_axis(axis_size, src, dst, x):
  if src is not_mapped:
    target_shape = list(np.shape(x))
    target_shape.insert(dst, axis_size)
    return broadcast(x, target_shape, [dst])
  elif src == dst:
    return x
  else:
    return moveaxis(x, src, dst)

def moveaxis(x, src: int, dst: int):
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return transpose(x, perm)

Tracer 用于向量化批处理,携带一个批处理值和一个可选整数,指示哪个轴(如果有的话)是批处理轴。

from typing import Union

class NotMapped: pass
not_mapped = NotMapped()

BatchAxis = Union[NotMapped, int]

class BatchTracer(Tracer):
  def __init__(self, trace, val, batch_dim: BatchAxis):
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim

  @property
  def aval(self):
    if self.batch_dim is not_mapped:
      return get_aval(self.val)
    else:
      return mapped_aval(self.batch_dim, get_aval(self.val))

  def full_lower(self):
    if self.batch_dim is not_mapped:
      return full_lower(self.val)
    else:
      return self

class BatchTrace(Trace):
  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)

  def process_primitive(self, primitive, tracers, params):
    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)
    vmap_rule = vmap_rules[primitive]
    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)
    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]

  @property
  def axis_size(self):
    return self.main.global_data

vmap_rules = {}

在这里,我们实现了可选的 Tracer.full_lower 方法,该方法允许我们去掉不需要的批处理跟踪器,因为它并不表示一个批量值。

对于 BatchTrace,与 JVPTrace 类似,purelift 方法只是将一个值装入 BatchTracer 中,同时最小化上下文,在这种情况下,上下文是一个将哨兵值 not_mapped 作为批次维度的 batch_dim。请注意,我们使用 MainTrace 的解释器全局数据字段来存储批次轴的大小。

接下来,我们可以为每个原语定义批处理解释器规则:

from functools import partial

def binop_batching_rule(op, axis_size, vals_in, dims_in):
  (x, y), (x_bdim, y_bdim) = vals_in, dims_in
  if x_bdim != y_bdim:
    if x_bdim is not_mapped:
      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)
      x_bdim = y_bdim
    else:
      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)
  return [op(x, y)], [x_bdim]
vmap_rules[add_p] = partial(binop_batching_rule, add)
vmap_rules[mul_p] = partial(binop_batching_rule, mul)

def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):
  (x,), (x_bdim,) = vals_in, dims_in
  return [op(x)], [x_bdim]
vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)
vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)
vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)

def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):
  (x,), (x_bdim,) = vals_in, dims_in
  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)
  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)
  return [reduce_sum(x, new_axis)], [out_bdim]
vmap_rules[reduce_sum_p] = reduce_sum_batching_rule

最后,我们添加一个转换 API 来启动跟踪:

def vmap_flat(f, in_axes, *args):
  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)
                if ax is not not_mapped}
  with new_main(BatchTrace, axis_size) as main:
    trace = BatchTrace(main)
    tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x
                  for x, ax in zip(args, in_axes)]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)
  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)
                     for val_out, bdim in zip(vals_out, bdims_out)]
  return outs_transposed

def vmap(f, in_axes):
  def batched_f(*args):
    args_flat, in_tree = tree_flatten(args)
    in_axes_flat, in_tree2 = tree_flatten(in_axes)
    if in_tree != in_tree2: raise TypeError
    f_flat, out_tree = flatten_fun(f, in_tree)
    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)
    return tree_unflatten(out_tree(), outs_flat)
  return batched_f
def add_one_to_a_scalar(scalar):
  assert np.ndim(scalar) == 0
  return 1 + scalar

vector_in = np.arange(3.)
vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)

print(vector_in)
print(vector_out)
[0. 1. 2.]
[1. 2. 3.]
def jacfwd(f, x):
  pushfwd = lambda v: jvp(f, (x,), (v,))[1]
  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)
  return vmap(pushfwd, (0,))(vecs_in)

def f(x):
  return sin(x)

jacfwd(f, np.arange(3.))
array([[ 1.        ,  0.        , -0.        ],
       [ 0.        ,  0.54030231, -0.        ],
       [ 0.        ,  0.        , -0.41614684]])

这就是关于jvpvmap的全部内容!

第二部分:Jaxprs#

下一个将要实现的变换是 jit,用于即时编译,以及 vjp,用于反向自动微分。(grad 只是 vjp 的一个小封装。) 虽然 jvpvmap 只需要每个 Tracer 携带一点额外的上下文,但对于 jitvjp,我们需要更丰富的上下文:我们需要表示_程序_。也就是说,我们需要 jaxprs!

Jaxprs 是 JAX 的内部程序中间表示。它们是显式类型的,函数式的,一阶的,采用 ANF 形式。我们需要 jit 的程序表示,因为 jit 的目的是将计算从 Python 中分离出来。对于我们想要分离的任何计算,我们需要能够将其表示为数据,并在追踪 Python 函数时逐步构建它。同样,vjp 需要一种方式表示反向传播的计算。我们对这两个需求使用相同的 jaxpr 程序表示。

(构建程序表示是最自由的追踪变换,因此除了处理本机 Python 控制流的问题外,任何变换都可以通过首先追踪到 jaxpr,然后解释 jaxpr 来实现。)

Jaxpr 数据结构#

jaxpr 术语语法大致为:

jaxpr ::=
  { lambda <binder> , ... .
    let <eqn>
        ...
    in ( <atom> , ... ) }

binder ::= <var>:<array_type>
var ::= a | b | c | ...
atom ::= <var> | <literal>
literal ::= <int32> | <int64> | <float32> | <float64>

eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...

类型的语法为:

jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]
array_type ::= <dtype>[<shape>]
dtype ::= f32 | f64 | i32 | i64
shape ::= <int> , ...

我们如何将这些表示为 Python 数据结构?我们重用 ShapedArrays 来表示类型,并且我们可以使用几个 Python 结构来表示术语语法:

class Var:
  aval: ShapedArray
  def __init__(self, aval): self.aval = aval

class Lit:
  val: Any
  aval: ShapedArray

  def __init__(self, val):
    self.aval = aval = raise_to_shaped(get_aval(val))
    self.val = np.array(val, aval.dtype)

Atom = Union[Var, Lit]

class JaxprEqn(NamedTuple):
  primitive: Primitive
  inputs: list[Atom]
  params: dict[str, Any]
  out_binders: list[Var]

class Jaxpr(NamedTuple):
  in_binders: list[Var]
  eqns: list[JaxprEqn]
  outs: list[Atom]

  def __hash__(self): return id(self)
  __eq__ = op.is_

def raise_to_shaped(aval):
  return ShapedArray(aval.shape, aval.dtype)

对jaxpr进行类型检查涉及以下几个方面:确保没有未绑定的变量,确保每个变量只被绑定一次,以及确保对于每个等式,原始应用的类型与输出绑定器的类型相匹配。

class JaxprType(NamedTuple):
  in_types:  list[ShapedArray]
  out_types: list[ShapedArray]

  def __repr__(self):
    in_types = ', '.join(aval.str_short() for aval in self.in_types)
    out_types = ', '.join(aval.str_short() for aval in self.out_types)
    return f'({in_types}) -> ({out_types})'

def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:
  env: set[Var] = set()

  for v in jaxpr.in_binders:
    if v in env: raise TypeError
    env.add(v)

  for eqn in jaxpr.eqns:
    in_types = [typecheck_atom(env, x) for x in eqn.inputs]
    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)
    for out_binder, out_type in zip(eqn.out_binders, out_types):
      if not out_type == out_binder.aval: raise TypeError
    for out_binder in eqn.out_binders:
      if out_binder in env: raise TypeError
      env.add(out_binder)

  in_types = [v.aval for v in jaxpr.in_binders]
  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]
  return JaxprType(in_types, out_types)

def typecheck_atom(env: set[Var], x: Atom) -> ShapedArray:
  if isinstance(x, Var):
    if x not in env: raise TypeError("unbound variable")
    return x.aval
  elif isinstance(x, Lit):
    return raise_to_shaped(get_aval(x.val))
  else:
    assert False

我们可以将jaxpr表示的函数应用于参数,使用一个简单的解释器。

def eval_jaxpr(jaxpr: Jaxpr, args: list[Any]) -> list[Any]:
  env: dict[Var, Any] = {}

  def read(x: Atom) -> Any:
    return env[x] if type(x) is Var else x.val

  def write(v: Var, val: Any) -> None:
    assert v not in env  # 单赋值
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_vals = map(read, eqn.inputs)
    outs = bind(eqn.primitive, *in_vals, **eqn.params)
    map(write, eqn.out_binders, outs)
  return map(read, jaxpr.outs)

def jaxpr_as_fun(jaxpr: Jaxpr):
  return lambda *args: eval_jaxpr(jaxpr, args)

通过在解释器中使用bind,该解释器本身是可以被跟踪的。

构建jaxprs与追踪#

现在我们已经有了jaxprs作为一种数据结构,我们需要从追踪Python代码生成这些jaxprs的方法。一般来说,我们有两种追踪到jaxpr的变体;jit使用其中一种,而vjp使用另一种。我们将从jit使用的那种开始,这种方式也被控制流原语如lax.condlax.while_looplax.scan所使用。

def split_list(lst: list[Any], n: int) -> tuple[list[Any], list[Any]]:
  assert 0 <= n <= len(lst)
  return lst[:n], lst[n:]

def partition_list(bs: list[bool], l: list[Any]) -> tuple[list[Any], list[Any]]:
  assert len(bs) == len(l)
  lists = lst1, lst2 = [], []
  for b, x in zip(bs, l):
    lists[b].append(x)
  return lst1, lst2
# 注:JAX 中类似的类称为 'DynamicJaxprTracer'
class JaxprTracer(Tracer):
  __slots__ = ['aval']
  aval: ShapedArray

  def __init__(self, trace, aval):
    self._trace = trace
    self.aval = aval

# 注:JAX 中类似的类称为 'DynamicJaxprTrace'
class JaxprTrace(Trace):
  def new_arg(self, aval: ShapedArray) -> JaxprTracer:
    aval = raise_to_shaped(aval)
    tracer = self.builder.new_tracer(self, aval)
    self.builder.tracer_to_var[id(tracer)] = Var(aval)
    return tracer

  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:
    tracer = self.builder.const_tracers.get(id(val))
    if tracer is None:
      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))
      self.builder.add_const(tracer, val)
    return tracer
  pure = lift = get_or_make_const_tracer

  def process_primitive(self, primitive, tracers, params):
    avals_in = [t.aval for t in tracers]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]
    inputs = [self.builder.getvar(t) for t in tracers]
    outvars = [self.builder.add_var(t) for t in out_tracers]
    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))
    return out_tracers

  @property
  def builder(self):
    return self.main.global_data

# 注意:在 JAX 中,我们改为将抽象评估规则附加到 Primitive 实例上。
abstract_eval_rules = {}

请注意,我们将一个构建器对象作为解释器全局数据保留,该对象在构建jaxpr时跟踪变量、常量和方程。

class JaxprBuilder:
  eqns: list[JaxprEqn]
  tracer_to_var: dict[int, Var]
  const_tracers: dict[int, JaxprTracer]
  constvals: dict[Var, Any]
  tracers: list[JaxprTracer]

  def __init__(self):
    self.eqns = []
    self.tracer_to_var = {}
    self.const_tracers = {}
    self.constvals = {}
    self.tracers = []

  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:
    tracer = JaxprTracer(trace, aval)
    self.tracers.append(tracer)
    return tracer

  def add_eqn(self, eqn: JaxprEqn) -> None:
    self.eqns.append(eqn)

  def add_var(self, tracer: JaxprTracer) -> Var:
    assert id(tracer) not in self.tracer_to_var
    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)
    return var

  def getvar(self, tracer: JaxprTracer) -> Var:
    var = self.tracer_to_var.get(id(tracer))
    assert var is not None
    return var

  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:
    var = self.add_var(tracer)
    self.const_tracers[id(val)] = tracer
    self.constvals[var] = val
    return var

  def build(self, in_tracers: list[JaxprTracer], out_tracers: list[JaxprTracer]
            ) -> tuple[Jaxpr, list[Any]]:
    constvars, constvals = unzip2(self.constvals.items())
    t2v = lambda t: self.tracer_to_var[id(t)]
    in_binders = constvars + [t2v(t) for t in in_tracers]
    out_vars = [t2v(t) for t in out_tracers]
    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)
    typecheck_jaxpr(jaxpr)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    return jaxpr, constvals
def _inline_literals(jaxpr: Jaxpr, consts: list[Any]) -> tuple[Jaxpr, list[Any]]:
  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))
  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]
  new_const_binders, lit_binders = partition_list(scalars, const_binders)
  new_consts, lit_vals = partition_list(scalars, consts)
  literals = dict(zip(lit_binders, map(Lit, lit_vals)))
  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],
                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]
  new_outs = [literals.get(x, x) for x in jaxpr.outs]
  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)
  typecheck_jaxpr(new_jaxpr)
  return new_jaxpr, new_consts

我们需要的 JaxprTrace.process_primitive 规则本质上是原始应用的类型规则:给定原始函数、其参数以及输入的类型,该规则必须生成输出的类型,然后与输出的 JaxprTracer 一起打包。我们可以使用抽象评估规则来实现这一目的,即使它们可能更为通用(因为抽象评估规则必须接受 ConcreteArray 输入,并且因为它们只需要返回可能输出集合的上界,因此它们也可以生成 ConcreteArray 输出)。我们将重用这些抽象评估规则用于其他生成 jaxpr 的跟踪机制,在那里潜在的额外通用性是有用的。

def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[add_p] = binop_abstract_eval
abstract_eval_rules[mul_p] = binop_abstract_eval

def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> list[ShapedArray]:
  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):
    raise TypeError
  if x.shape != y.shape: raise TypeError
  return [ShapedArray(x.shape, np.dtype('bool'))]
abstract_eval_rules[greater_p] = compare_abstract_eval
abstract_eval_rules[less_p] = compare_abstract_eval

def vectorized_unop_abstract_eval(x: ShapedArray) -> list[ShapedArray]:
  return [ShapedArray(x.shape, x.dtype)]

abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval
abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval
abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval

def reduce_sum_abstract_eval(x: ShapedArray, *, axis: tuple[int, ...]
                             ) -> list[ShapedArray]:
  axis_ = set(axis)
  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]
  return [ShapedArray(tuple(new_shape), x.dtype)]
abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval

def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],
                            axes: Sequence[int]) -> list[ShapedArray]:
  return [ShapedArray(tuple(shape), x.dtype)]
abstract_eval_rules[broadcast_p] = broadcast_abstract_eval

为了检查我们对jaxprs的实现,我们可以添加一个make_jaxpr变换和一个美化打印程序:

from functools import lru_cache

@lru_cache  # ShapedArrays 是可哈希的
def make_jaxpr_v1(f, *avals_in):
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    trace = JaxprTrace(main)
    tracers_in = [trace.new_arg(aval) for aval in avals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()
Hide code cell source
from collections import defaultdict
import string

class PPrint:
  lines: list[tuple[int, str]]

  def __init__(self, lines):
    self.lines = lines

  def indent(self, indent: int) -> 'PPrint':
    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])

  def __add__(self, rhs: 'PPrint') -> 'PPrint':
    return PPrint(self.lines + rhs.lines)

  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':
    if not rhs.lines: return self
    if not self.lines: return rhs
    indent, s = self.lines[-1]
    indented_block = rhs.indent(indent + len(s))
    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]
    return PPrint(self.lines[:-1]
                  + [(indent, common_line)]
                  + indented_block.lines[1:])

  def __str__(self) -> str:
    return '\n'.join(' ' * indent + s for indent, s in self.lines)

def pp(s: Any) -> PPrint:
  return PPrint([(0, line) for line in str(s).splitlines()])

def vcat(ps: list[PPrint]) -> PPrint:
  return sum(ps, pp(''))

def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:
  namegen = (''.join(s) for r in it.count(1)
             for s in it.permutations(string.ascii_lowercase, r))
  names = defaultdict(lambda: next(namegen))
  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)
  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])
  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)
                   for v in jaxpr.outs)
  return (pp(f'{{ lambda {in_binders} .') +
          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))

def var_str(names: defaultdict[Var, str], v: Var) -> str:
  return f'{names[v]}:{v.aval.str_short()}'

def pp_eqn(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  rule = pp_rules.get(eqn.primitive)
  if rule:
    return rule(names, eqn)
  else:
    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>
           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                       for x in eqn.inputs)))
    return lhs >> pp(' = ') >> rhs

def pp_params(params: dict[str, Any]) -> PPrint:
  items = sorted(params.items())
  if items:
    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')
  else:
    return pp(' ')

Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))
pp_rules: dict[Primitive, Callable[..., PPrint]] = {}
jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))
print(jaxpr)
print(typecheck_jaxpr(jaxpr))
{ lambda a:float64[] .
  let b:float64[] = mul 2.0 a
  in ( b ) }
(float64[]) -> (float64[])

但这里有一个限制:由于find_top_trace根据数据依赖的方式进行操作,make_jaxpr_v1无法将其所给定的Python调用中执行的所有基本操作进行分阶段处理。例如:

jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let 
  in ( 4.0 ) }

这正是omnistaging所解决的问题。我们希望确保 make_jaxpr 启动的 JaxprTrace 始终被应用,无论传递给 bind 的任何输入是否在相应的 JaxprTracer 实例中被封装。我们可以通过使用第1部分中定义的 dynamic_trace 全局变量来实现这一点:

@contextmanager
def new_dynamic(main: MainTrace):
  global dynamic_trace
  prev_dynamic_trace, dynamic_trace = dynamic_trace, main
  try:
    yield
  finally:
    dynamic_trace = prev_dynamic_trace

@lru_cache
def make_jaxpr(f: Callable, *avals_in: ShapedArray,
               ) -> tuple[Jaxpr, list[Any], PyTreeDef]:
  avals_in, in_tree = tree_flatten(avals_in)
  f, out_tree = flatten_fun(f, in_tree)

  builder = JaxprBuilder()
  with new_main(JaxprTrace, builder) as main:
    with new_dynamic(main):
      trace = JaxprTrace(main)
      tracers_in = [trace.new_arg(aval) for aval in avals_in]
      outs = f(*tracers_in)
      tracers_out = [full_raise(trace, out) for out in outs]
      jaxpr, consts = builder.build(tracers_in, tracers_out)
  return jaxpr, consts, out_tree()

jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))
print(jaxpr)
{ lambda  .
  let a:float64[] = mul 2.0 2.0
  in ( a ) }

使用 dynamic_trace 这种方式在概念上与存储当前解释器栈并在底部启动一个新的 JaxprTrace 是相同的。也就是说,在 dynamic_trace 之下的栈中不会应用更低层的解释器(因为 JaxprTrace.process_primitive 并不调用 bind),不过如果被追踪到 jaxpr 的 Python 可调用对象本身使用了变换,那么这些变换可以被推入到 JaxprTrace 之上的解释器栈中。但是,暂时存储解释器栈会打破系统状态。dynamic_trace 标签实现了相同的目标,同时保持了系统状态的简单性。

这就是jaxprs的全部内容!有了jaxprs,我们可以实现剩余的主要JAX功能。

第三部分:jit,简化版#

虽然 jit 具有类似于变换的 API,因为它接受一个 Python 可调用对象作为参数,但在内部它实际上是一个高阶基元,而不是一个变换。当一个基元通过一个函数进行参数化时,它被称为 高阶

按需处理(“最终风格”)与分阶段处理(“初始风格”)#

对于如何处理高阶原语有两种选择。每种选择需要不同的追踪方法,并带来不同的权衡:

  1. 按需处理,其中 bind 将一个 Python 可调用对象作为参数。 我们推迟形成 jaxpr,直到尽可能晚,也就是说,直到我们在解释器栈底部运行最终解释器。这样我们就可以在解释器栈底部插入一个 JaxprTrace,因此可以进行分阶段处理而不是直接执行所有原语操作。使用这种方法,栈中的变换会在我们像往常一样执行 Python 可调用对象时应用。这个方法的实现可能非常棘手,但它尽可能地通用,因为它允许高阶原语不提升其参数的抽象水平,从而允许数据相关的 Python 控制流。我们将这种方法称为使用“最终风格的高阶原语”,采用我们到目前为止使用的在追踪时排放的“最终风格变换”。

  2. 分阶段处理,其中 bind 将一个 jaxpr 作为参数。 在我们调用 bind 之前,在原语封装中,我们可以直接使用 make_jaxpr 提前形成一个 jaxpr,并完全处理 Python 可调用对象。在这种情况下,make_jaxpr 将其 JaxprTrace 放在解释器栈的顶部,并且栈下方的没有变换(这些变换可能通过闭合的 Tracers 进入)不会应用于我们追踪的 Python 可调用对象。 (在 Python 可调用对象内部应用的变换会像往常一样被应用,添加到 JaxprTrace 之上的栈中。)相反,栈下方的变换稍后会应用于调用原语,而调用原语的规则必须转变 jaxpr 本身。因为我们提前追踪到 jaxpr,所以这种方法无法支持数据相关的 Python 控制流,但实现起来更为直接。我们将这种类型的高阶原语称为“初始风格的高阶原语”,并且说它的 jaxpr 处理变换规则是“初始风格变换规则”。

后者的方法适用于 jit,因为我们不需要支持用户提供的 Python 可调用对象中的数据相关的 Python 控制流,因为 jit 的整个目的就是将计算从 Python 中分阶段出来,由 XLA 执行。(相反,custom_jvp 是一种高阶原语,我们希望支持数据相关的 Python 控制流。)

历史上,我们在阅读 类型标记无终极解释器 论文后开始使用“初始风格”和“最终风格”这个术语,并戏称 JAX 为“无类型标记最终解释器”的实现。我们并不声称这些术语背后有任何深层含义(或理解其含义);我们笼统地使用“初始风格”来表示“构建抽象语法树然后变换”,而使用“最终风格”来表示“在追踪时变换”。但这只是模糊而又顽固的行话。

使用初始风格的方法,用户界面的 jit 包装器如下:

def jit(f):
  def f_jitted(*args):
    avals_in = [raise_to_shaped(get_aval(x)) for x in args]
    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)
    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))
    return tree_unflatten(out_tree, outs)
  return f_jitted

xla_call_p = Primitive('xla_call')

对于任何新的原语,我们需要制定转换规则,首先是它的评估规则。当我们评估xla_call原语的应用时,我们希望将计算转移到XLA。这涉及将jaxpr转换为XLA HLO程序,将参数值传输到XLA设备上,执行XLA程序,并将结果传回。我们将缓存XLA HLO编译,以便每个jitted函数仅需针对每个参数形状和数据类型签名执行一次。

首先,一些工具函数。

class IDHashable:
  val: Any

  def __init__(self, val):
    self.val = val

  def __hash__(self) -> int:
    return id(self.val)

  def __eq__(self, other):
    return type(other) is IDHashable and id(self.val) == id(other.val)

接下来,我们将定义 xla_call 的评估规则:

from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
xe = xc._xla
xops = xc._xla.ops

def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):
  consts, args = args[:num_consts], args[num_consts:]
  hashable_consts = tuple(map(IDHashable, consts))
  execute = xla_callable(IDHashable(jaxpr), hashable_consts)
  return execute(*args)
impl_rules[xla_call_p] = xla_call_impl

@lru_cache
def xla_callable(hashable_jaxpr: IDHashable,
                 hashable_consts: tuple[IDHashable, ...]):
  jaxpr: Jaxpr = hashable_jaxpr.val
  typecheck_jaxpr(jaxpr)
  consts = [x.val for x in hashable_consts]
  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]
  c = xc.XlaBuilder('xla_call')
  xla_consts = _xla_consts(c, consts)
  xla_params = _xla_params(c, in_avals)
  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)
  out = xops.Tuple(c, outs)
  compiled = xb.get_backend(None).compile(
    xc._xla.mlir.xla_computation_to_mlir_module(c.build(out)))
  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])

def _xla_consts(c: xe.XlaBuilder, consts: list[Any]) -> list[xe.XlaOp]:
  unique_consts = {id(cnst): cnst for cnst in consts}
  xla_consts = {
      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}
  return [xla_consts[id(cnst)] for cnst in consts]

def _xla_params(c: xe.XlaBuilder, avals_in: list[ShapedArray]) -> list[xe.XlaOp]:
  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]

def _xla_shape(aval: ShapedArray) -> xe.Shape:
  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)

主要的动作在于 xla_callable,它利用 jaxpr_subcomp 将一个 jaxpr 编译为一个 XLA HLO 程序,然后返回一个可调用对象来执行编译后的程序:

def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: list[xe.XlaOp]
                  ) -> list[xe.XlaOp]:
  env: dict[Var, xe.XlaOp] = {}

  def read(x: Atom) -> xe.XlaOp:
    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))

  def write(v: Var, val: xe.XlaOp) -> None:
    env[v] = val

  map(write, jaxpr.in_binders, args)
  for eqn in jaxpr.eqns:
    in_avals = [x.aval for x in eqn.inputs]
    in_vals = map(read, eqn.inputs)
    rule = xla_translations[eqn.primitive]
    out_vals = rule(c, in_avals, in_vals, **eqn.params)
    map(write, eqn.out_binders, out_vals)
  return map(read, jaxpr.outs)

def execute_compiled(compiled, out_avals, *args):
  input_bufs = [input_handlers[type(x)](x) for x in args]
  out_bufs = compiled.execute(input_bufs)
  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]

default_input_handler = xb.get_backend(None).buffer_from_pyval
input_handlers = {ty: default_input_handler for ty in
                  [bool, int, float, np.ndarray, np.float64, np.float32]}

def handle_result(aval: ShapedArray, buf):
  del aval  # 暂未使用
  return np.asarray(buf)

xla_translations = {}

注意到 jaxpr_subcomp 的结构像一个简单的解释器。这是一种常见的模式:我们处理 jaxprs 的方式通常是通过一个解释器。和任何解释器一样,我们需要为每一个原始操作定义一个解释规则:

def direct_translation(op, c, in_avals, in_vals):
  del c, in_avals
  return [op(*in_vals)]

xla_translations[add_p] = partial(direct_translation, xops.Add)
xla_translations[mul_p] = partial(direct_translation, xops.Mul)
xla_translations[neg_p] = partial(direct_translation, xops.Neg)
xla_translations[sin_p] = partial(direct_translation, xops.Sin)
xla_translations[cos_p] = partial(direct_translation, xops.Cos)
xla_translations[greater_p] = partial(direct_translation, xops.Gt)
xla_translations[less_p] = partial(direct_translation, xops.Lt)

def reduce_sum_translation(c, in_avals, in_vals, *, axis):
  (x_aval,), (x,) = in_avals, in_vals
  zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))
  subc = xc.XlaBuilder('add')
  shape = _xla_shape(ShapedArray((), x_aval.dtype))
  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))
  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]
xla_translations[reduce_sum_p] = reduce_sum_translation

def broadcast_translation(c, in_avals, in_vals, *, shape, axes):
  x, = in_vals
  dims_complement = [i for i in range(len(shape)) if i not in axes]
  return [xops.BroadcastInDim(x, shape, dims_complement)]
xla_translations[broadcast_p] = broadcast_translation

通过这样,我们现在可以使用 jit 来分阶段、编译和执行带有 XLA 的程序!

@jit
def f(x, y):
  print('tracing!')
  return sin(x) * cos(y)
z = f(3., 4.)  # 'tracing!' prints the first time
print(z)
tracing!
-0.09224219304455371
z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!
print(z)
-0.21467624978306998
@jit
def f(x):
  return reduce_sum(x, axis=0)

print(f(np.array([1., 2., 3.])))
6.0
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def deriv(f):
  return lambda x: jvp(f, (x,), (1.,))[1]

print(    deriv(deriv(f))(3.))
print(jit(deriv(deriv(f)))(3.))
0.2822400161197344
0.2822400161197344

而不是先实现 jit 来跟踪到 jaxpr,然后将 jaxpr 降低到 XLA HLO,可能看起来我们可以跳过 jaxpr 步骤,直接在跟踪时降到 HLO。也就是说,也许我们可以改为用一个 TraceTracer 实现 jit,在每个原始绑定时逐步添加到 XLA HLO 图中。现在这是正确的,但在我们引入编译的 SPMD 计算时将无法实现,因为在编译程序之前,我们必须知道所需的副本数量。

我们还没有为 xla_call_p 定义任何转换规则,除了它的评估规则。也就是说,我们还无法进行 vmap-of-jitjvp-of-jit,甚至不能进行 jit-of-jit。相反,jit 必须处于“顶层”。让我们来解决这个问题!

def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):
  del num_consts  # 未使用的
  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)
  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  n = len(outs) // 2
  primals_out, tangents_out = outs[:n], outs[n:]
  return primals_out, tangents_out
jvp_rules[xla_call_p] = xla_call_jvp_rule

@lru_cache
def jvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
  def jvp_traceable(*primals_and_tangents):
    n = len(primals_and_tangents) // 2
    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]
    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)

  in_avals = [v.aval for v in jaxpr.in_binders]
  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)
  return new_jaxpr, new_consts
def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):
  del num_consts  # 未使用的
  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))
  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,
              num_consts=len(new_consts))
  return outs, [0] * len(outs)
vmap_rules[xla_call_p] = xla_call_vmap_rule

@lru_cache
def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: tuple[BatchAxis, ...]
               ) -> tuple[Jaxpr, list[Any]]:
  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))
  in_avals = [unmapped_aval(axis_size, d, v.aval)
              for v, d in zip(jaxpr.in_binders, bdims_in)]
  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)
  return new_jaxpr, new_consts

def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray
                  ) -> ShapedArray:
  if batch_dim is not_mapped:
    return aval
  else:
    shape = list(aval.shape)
    shape.insert(batch_dim, axis_size)
    return ShapedArray(tuple(shape), aval.dtype)
def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):
  del num_consts  # 未使用的
  jaxpr_type = typecheck_jaxpr(jaxpr)
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule

def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):
  del num_consts  # 仅在顶级使用。
  # 直接调用 jaxpr_subcomp 会导致内联。我们改为生成一个 Call HLO。
  subc = xc.XlaBuilder('inner xla_call')
  xla_params = _xla_params(subc, in_avals)
  outs = jaxpr_subcomp(subc, jaxpr, xla_params)
  subc = subc.build(xops.Tuple(subc, outs))
  return destructure_tuple(c, xops.Call(c, subc, in_vals))
xla_translations[xla_call_p] = xla_call_translation

def destructure_tuple(c, tup):
  num_elements = len(c.get_shape(tup).tuple_shapes())
  return [xops.GetTupleElement(tup, i) for i in range(num_elements)]
@jit
def f(x):
  print('tracing!')
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
tracing!
2.7177599838802657
2.979984993200891
y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed
ys = vmap(f, (0,))(np.arange(3.))
print(ys)
[ 0.         -0.68294197  0.18140515]

缺少的一部分是数组的设备内存持久性。也就是说,我们已经定义了 handle_result 将结果作为 NumPy 数组传回 CPU 内存,但通常更可取的是避免传输结果,以便在下一次操作时再转移它们。我们可以通过引入一个 Array 类来做到这一点,该类可以封装 XLA 缓冲区,并且在其他方面模拟 numpy.ndarray 的行为:

def handle_result(aval: ShapedArray, buf):  # noqa:F811
  return Array(aval, buf)

class Array:
  buf: Any
  aval: ShapedArray

  def __init__(self, aval, buf):
    self.aval = aval
    self.buf = buf

  dtype = property(lambda self: self.aval.dtype)
  shape = property(lambda self: self.aval.shape)
  ndim  = property(lambda self: self.aval.ndim)

  def __array__(self): return np.asarray(self.buf)
  def __repr__(self):  return repr(np.asarray(self.buf))
  def __str__(self):   return str(np.asarray(self.buf))

  _neg = staticmethod(neg)
  _add = staticmethod(add)
  _radd = staticmethod(add)
  _mul = staticmethod(mul)
  _rmul = staticmethod(mul)
  _gt = staticmethod(greater)
  _lt = staticmethod(less)
input_handlers[Array] = lambda x: x.buf

jax_types.add(Array)
@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

x, xdot = 3., 1.
y, ydot = jvp(f, (x,), (xdot,))
print(y)
print(ydot)
2.7177599838802657
2.979984993200891
Hide code cell source
def pprint_xla_call(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}
  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(eqn.params['jaxpr']).indent(2)])
pp_rules[xla_call_p] = pprint_xla_call

第四部分:linearizevjp (以及 grad!)#

linearizevjp 自动微分函数是基于 jvp 构建的,但也涉及 jaxprs。这是因为这两者都涉及到计算的分阶段或延迟。

linearize#

linearize 的情况下,我们想要对 jvp 计算的线性部分进行分阶段处理。也就是说,从 Haskell风格的类型签名来看,如果我们有 jvp : (a -> b) -> (a, T a) -> (b, T b), 那么我们写 linearize : (a -> b) -> a -> (b, T a -o T b),这里的 T a 表示“a 的切线类型”,而使用“棉花糖” -o 而不是箭头 -> 来表示一个 线性 函数。我们也将 linearize 的语义定义为基于 jvp

y, f_lin = linearize(f, x)
y_dot = f_lin(x_dot)

对于 (y, y_dot) 得到的结果与下面的表示相同:

y, y_dot = jvp(f, (x,), (x_dot,))

其中 f_lin 的应用不会重新进行任何线性化工作。 我们将延迟的线性部分 f_lin : T a -o T b 表示为 jaxpr。

顺便提一句,现在我们有了线性箭头 -o,我们可以为 jvp 提供一个稍微更有信息量的类型:

jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)

这里我们写 UnrestrictedUse 只是为了表明我们有一对特殊的元素,其中第一个元素可以以不受限制(非线性)的方式使用。 结合线性箭头,这种表示方法仅用于表达函数 jvp f 以非线性方式使用其第一个输入,而以线性方式使用其第二个输入,产生一个对应的非线性输出(可以以非线性方式使用)以及一个线性输出。这种更精细的类型签名编码了 jvp f 中的数据依赖关系,这对于部分求值是有用的。

为了从 JVP 构建 f_lin jaxpr,我们需要执行部分求值: 在跟踪时,我们评估所有的原始值,但将切线计算分阶段到一个 jaxpr 中。这是我们的第二种构建 jaxprs 的方式。但在 make_jaxpr 和其底层的 JaxprTrace / JaxprTracer 解释器旨在分阶段处理每一个原语绑定的情况下,这种第二种方法仅对那些与切线输入有数据依赖的原语绑定进行分阶段处理。

首先,一些工具:

def split_half(lst: list[Any]) -> tuple[list[Any], list[Any]]:
  assert not len(lst) % 2
  return split_list(lst, len(lst) // 2)

def merge_lists(which: list[bool], l1: list[Any], l2: list[Any]) -> list[Any]:
  l1, l2 = iter(l1), iter(l2)
  out = [next(l2) if b else next(l1) for b in which]
  assert next(l1, None) is next(l2, None) is None
  return out

接下来,我们将通过将 jvp 与一种通用的局部评估转换结合在一起来编写 linearize,该转换将在下文中添加:

def linearize_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])
  return primals_out, f_lin

def linearize(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_lin(*tangents_in):
    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
    if in_tree != in_tree2: raise TypeError
    tangents_out_flat = f_lin_flat(*tangents_in_flat)
    return tree_unflatten(out_tree(), tangents_out_flat)

  return primals_out, f_lin

def vspace(aval: ShapedArray) -> ShapedArray:
  return raise_to_shaped(aval)  # 待办事项:处理整数?

现在我们转向一般的部分求值转换。我们的目标是接受一个 Python 可调用对象和一个输入列表,其中一些是已知的,另一些是未知的,并产生 (1) 可以从已知输入计算出的所有输出,以及 (2) 一个 jaxpr,表示 Python 可调用对象的计算部分,该部分只能在剩余输入已知之后执行。

这个转换在类型签名中很难总结。如果我们假设输入函数的类型签名为 (a1, a2) -> (b1, b2),其中 a1a2 分别表示已知和未知的输入,并且 b1 仅对 a1 有数据依赖,而 b2a2 有某些数据依赖,那么我们可以写出以下内容:

partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)

用语言描述,给定类型 a1 的输入值,partial_eval 生成类型 b1 的输出以及代表完成第二阶段计算所需中间值的存在量化类型 r 的“残余”值。它还生成一个类型为 (r, a2) -> b2 的函数,该函数接受残余值以及剩余输入并生成剩余输出。

我们喜欢把部分求值视为将一个计算“解压”成两个。例如,考虑这个 jaxpr:

{ lambda a:float64[] .
  let b:float64[] = sin a
      c:float64[] = neg b
  in ( c ) }

JVP 的 jaxpr 看起来是这样的:

{ lambda a:float64[] b:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      e:float64[] = mul d b
      f:float64[] = neg c
      g:float64[] = neg e
  in ( f, g ) }

如果我们想象将部分求值应用于这个 jaxpr,已知第一个输入而未知第二个输入,我们最终会将 JVP jaxpr “解压”成原始和切线 jaxpr:

{ lambda a:float64[] .
  let c:float64[] = sin a
      d:float64[] = cos a
      f:float64[] = neg c
  in ( f, d ) }
{ lambda d:float64[] b:float64[] .
  let e:float64[] = mul d b
      g:float64[] = neg e
  in ( g ) }

第二个 jaxpr 表示我们希望从 linearize 获取的线性计算。

然而,与这个 jaxpr 示例不同,我们希望在评估输入 Python 可调用对象时发生对已知值的计算。也就是说,我们希望仅形成一个 jaxpr,对于那些由于依赖于未知输入而必须延迟的操作。在自动微分的上下文中,这一特性最终使我们能够处理像 grad(lambda x: x**2 if x > 0 else 0.) 这样的函数。Python 控制流之所以有效,是因为部分求值将原始计算保持在 Python 中。因此,我们的 TraceTracer 子类必须动态地理清哪些可以被评估,哪些必须被分阶段到 jaxpr 中。

首先,我们从一个 PartialVal 类开始,它表示一个可以是已知或未知的值:

class PartialVal(NamedTuple):
  aval: ShapedArray
  const: Any | None

  @classmethod
  def known(cls, val: Any):
    return PartialVal(get_aval(val), val)

  @classmethod
  def unknown(cls, aval: ShapedArray):
    return PartialVal(aval, None)

  is_known   = property(lambda self: self.const is not None)
  is_unknown = property(lambda self: self.const is     None)

部分求值将接受一个表示输入的 PartialVal 列表,并返回一个 PartialVal 输出列表,以及一个表示延迟计算的 jaxpr:

def partial_eval_flat(f: Callable, pvals_in: list[PartialVal]
                      ) -> tuple[Jaxpr, list[PartialVal], list[Any]]:
  with new_main(PartialEvalTrace) as main:
    trace = PartialEvalTrace(main)
    tracers_in = [trace.new_arg(pval) for pval in pvals_in]
    outs = f(*tracers_in)
    tracers_out = [full_raise(trace, out) for out in outs]
    pvals_out = [t.pval for t in tracers_out]
    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]
    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]
    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)
  return jaxpr, pvals_out, consts

接下来,我们需要实现 PartialEvalTrace 和它的 PartialEvalTracer。这个解释器将实时构建一个 jaxpr,同时追踪数据依赖关系。为此,它在 PartialEvalTracer 节点(表示分阶段的值)和 JaxprRecipe 节点(表示如何根据其他值计算某些值的公式)之间构建一个二分有向无环图(DAG)。其中一种配方是 JaxprEqnRecipe,对应于 JaxprEqn 的原始应用,但我们还有常量和 lambda 绑定的配方类型:

from weakref import ref, ReferenceType

class LambdaBindingRecipe(NamedTuple):
  pass

class ConstRecipe(NamedTuple):
  val: Any

class JaxprEqnRecipe(NamedTuple):
  prim: Primitive
  tracers_in: list['PartialEvalTracer']
  params: dict[str, Any]
  avals_out: list[ShapedArray]
  tracer_refs_out: list['ReferenceType[PartialEvalTracer]']

JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]
class PartialEvalTracer(Tracer):
  pval: PartialVal
  recipe: JaxprRecipe | None

  def __init__(self, trace, pval, recipe):
    self._trace = trace
    self.pval = pval
    self.recipe = recipe

  aval = property(lambda self: self.pval.aval)

  def full_lower(self):
    if self.pval.is_known:
      return full_lower(self.pval.const)
    return self

PartialEvalTrace 包含构建 JaxprRecipePartialEvalTracer 图的逻辑。每个参数对应于一个 LambdaBindingRecipe 叶节点, 每个常量是一个 ConstRecipe 叶节点,持有对常量的引用。所有其他的追踪器和食谱来自 process_primitive,它通过 JaxprEqnRecipe 形成追踪器。

对于大多数原语,process_primitive 的逻辑是简单直接的:如果所有输入都是已知的,那么我们可以在已知值上绑定原语(在 Python 中评估它),并避免形成对应于输出的追踪器。如果任何输入未知,那么我们将转而使用表示原语应用的 JaxprEqnRecipe。要构建表示未知输出的追踪器,我们需要 avals,这些 avals 从抽象评估规则中获取。(注意,追踪器引用 JaxprEqnRecipe,而 JaxprEqnRecipe 引用追踪器;我们通过使用 weakrefs 避免循环垃圾。)

这个 process_primitive 的逻辑适用于大多数原语,但 xla_call_p 需要递归处理。因此,我们在 partial_eval_rules 字典中对其规则进行特殊处理。

class PartialEvalTrace(Trace):
  def new_arg(self, pval: PartialVal) -> Any:
    return PartialEvalTracer(self, pval, LambdaBindingRecipe())

  def lift(self, val: Any) -> PartialEvalTracer:
    return PartialEvalTracer(self, PartialVal.known(val), None)
  pure = lift

  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:
    if tracer.pval.is_unknown:
      return tracer
    else:
      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))
      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))

  def process_primitive(self, primitive, tracers, params):
    if all(t.pval.is_known for t in tracers):
      return bind(primitive, *map(full_lower, tracers), **params)
    rule = partial_eval_rules.get(primitive)
    if rule: return rule(self, tracers, **params)
    tracers_in = [self.instantiate_const(t) for t in tracers]
    avals_in = [t.aval for t in tracers_in]
    avals_out = abstract_eval_rules[primitive](*avals_in, **params)
    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)
                   for aval in avals_out]
    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,
                         map(ref, tracers_out))
    for t in tracers_out: t.recipe = eqn
    return tracers_out

partial_eval_rules = {}

现在我们可以使用 PartialEvalTrace 构建 jaxprs 的图形表示,我们需要一个机制将图形表示转换为标准的 jaxpr。jaxpr 对应于图的拓扑排序。

def tracers_to_jaxpr(tracers_in: list[PartialEvalTracer],
                     tracers_out: list[PartialEvalTracer]):
  tracer_to_var: dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))
                                   for t in tracers_in}
  constvar_to_val: dict[int, Any] = {}
  constid_to_var: dict[int, Var] = {}
  processed_eqns: set[int] = set()
  eqns: list[JaxprEqn] = []
  for t in toposort(tracers_out, tracer_parents):
    if isinstance(t.recipe, LambdaBindingRecipe):
      assert id(t) in set(map(id, tracers_in))
    elif isinstance(t.recipe, ConstRecipe):
      val = t.recipe.val
      var = constid_to_var.get(id(val))
      if var is None:
        aval = raise_to_shaped(get_aval(val))
        var = constid_to_var[id(val)] = Var(aval)
        constvar_to_val[var] = val
      tracer_to_var[id(t)] = var
    elif isinstance(t.recipe, JaxprEqnRecipe):
      if id(t.recipe) not in processed_eqns:
        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))
        processed_eqns.add(id(t.recipe))
    else:
      raise TypeError(t.recipe)

  constvars, constvals = unzip2(constvar_to_val.items())
  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]
  out_vars = [tracer_to_var[id(t)] for t in tracers_out]
  jaxpr = Jaxpr(in_binders, eqns, out_vars)
  typecheck_jaxpr(jaxpr)
  return jaxpr, constvals

def recipe_to_eqn(tracer_to_var: dict[int, Var], recipe: JaxprEqnRecipe
                  ) -> JaxprEqn:
  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]
  out_binders = [Var(aval) for aval in recipe.avals_out]
  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):
    if t_ref() is not None: tracer_to_var[id(t_ref())] = var
  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)

def tracer_parents(t: PartialEvalTracer) -> list[PartialEvalTracer]:
  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []
Hide code cell source
def toposort(out_nodes: list[Any], parents: Callable[[Any], list[Any]]):
  if not out_nodes: return []
  out_nodes = remove_duplicates(out_nodes)

  child_counts = {}
  stack = list(out_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(parents(node))
  for node in out_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in parents(node):
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  sorted_nodes = sorted_nodes[::-1]
  check_toposort(sorted_nodes, parents)
  return sorted_nodes

def remove_duplicates(lst):
  seen = set()
  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]

def check_toposort(nodes: list[Any], parents: Callable[[Any], list[Any]]):
  seen = set()
  for node in nodes:
    assert all(id(parent) in seen for parent in parents(node))
    seen.add(id(node))

现在我们可以进行线性化!

y, sin_lin = linearize(sin, 3.)
print(y, sin(3.))
print(sin_lin(1.), cos(3.))
0.1411200080598672 0.1411200080598672
-0.9899924966004454 -0.9899924966004454

为了处理linearizejit,我们仍然需要为xla_call_p编写一个部分评估规则。除了跟踪书籍管理,主要任务是对jaxpr进行部分评估,将其“解压缩”为两个jaxpr。

实际上需要编写两个规则:一个用于追踪时的部分评估,称为xla_call_partial_eval,另一个用于jaxpr的部分评估,称为xla_call_peval_eqn

def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):
  del num_consts  # 未使用的
  in_unknowns = [not t.pval.is_known for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)
  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)
  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,
                       dict(jaxpr=jaxpr2, num_consts=0),
                       [v.aval for v in jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_unknowns, outs1, outs2)
partial_eval_rules[xla_call_p] = xla_call_partial_eval

def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: list[bool],
                       instantiate: list[bool] | None = None,
                       ) -> tuple[Jaxpr, Jaxpr, list[bool], int]:
  env: dict[Var, bool] = {}
  residuals: set[Var] = set()

  def read(x: Atom) -> bool:
    return type(x) is Var and env[x]

  def write(unk: bool, v: Var) -> None:
    env[v] = unk

  def new_res(x: Atom) -> Atom:
    if type(x) is Var: residuals.add(x)
    return x

  eqns1, eqns2 = [], []
  map(write, in_unknowns, jaxpr.in_binders)
  for eqn in jaxpr.eqns:
    unks_in = map(read, eqn.inputs)
    rule = partial_eval_jaxpr_rules.get(eqn.primitive)
    if rule:
      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)
      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)
      map(write, unks_out, eqn.out_binders)
    elif any(unks_in):
      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]
      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))
      map(partial(write, True), eqn.out_binders)
    else:
      eqns1.append(eqn)
      map(partial(write, False), eqn.out_binders)
  out_unknowns = map(read, jaxpr.outs)
  if instantiate is not None:
    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):
      if inst and not uk: new_res(v)
    out_unknowns = map(op.or_, out_unknowns, instantiate)

  residuals, num_res = list(residuals), len(residuals)
  assert all(type(v) is Var for v in residuals), residuals

  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)
  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)

  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)
  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)
  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)

  return jaxpr1, jaxpr2, out_unknowns, num_res

def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):
  jaxprty = typecheck_jaxpr(jaxpr)    # (a1, a2) -> (b1, b2)
  jaxpr1ty = typecheck_jaxpr(jaxpr1)  # a1       -> (b1, res)
  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2

  a1, a2 = partition_list(unks_in, jaxprty.in_types)
  b1, b2 = partition_list(unks_out, jaxprty.out_types)
  b1_, res = split_list(jaxpr1ty.out_types, len(b1))
  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))
  b2_ = jaxpr2ty.out_types

  if jaxpr1ty.in_types != a1: raise TypeError
  if jaxpr2ty.out_types != b2: raise TypeError
  if b1 != b1_: raise TypeError
  if res != res_: raise TypeError
  if a2 != a2_: raise TypeError
  if b2 != b2_: raise TypeError

partial_eval_jaxpr_rules = {}

def xla_call_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                       ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Var]]:
  jaxpr = eqn.params['jaxpr']
  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)
  ins1, ins2 = partition_list(unks_in, eqn.inputs)
  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)
  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]
  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),
                  out_binders1 + residuals)
  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,
                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)
  return eqn1, eqn2, unks_out, residuals
partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn

因此,我们可以根据需要组合 linearizejit

@jit
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
2.7177599838802657 2.979984993200891
@jit
def f(x):
  y = sin(x) * 2.
  z = g(x, y)
  return z

@jit
def g(x, y):
  return cos(x) + y

y, f_lin = linearize(f, 3.)
y_dot = f_lin(1.)
print(y, y_dot)
-0.7077524804807109 -2.121105001260758

vjpgrad#

vjp 转换与 linearize 非常相似。它的类型签名是类似的:

linearize : (a -> b) -> a -> (b, T a -o T b)
vjp       : (a -> b) -> a -> (b, T b -o T a)

唯一的区别是我们在返回结果之前转置计算的线性部分,这样它就从类型 T a -o T b 变成了类型 T b -o T a。也就是说,我们将 vjp 实现为本质上如下:

def vjp(f, x):
  y, f_lin = linearize(f, x)
  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)
  return y, f_vjp

因为我们将线性计算作为 jaxpr,而不仅仅是 Python 可调用对象,我们可以将转置转换实现为 jaxpr 解释器。

def vjp_flat(f, *primals_in):
  pvals_in = ([PartialVal.known(x) for x in primals_in] +
              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])
  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)
  def f_jvp(*primals_tangents_in):
    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))
    return [*primals_out, *tangents_out]
  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # 线性化
  primal_pvals, _ = split_half(pvals_out)
  assert all(pval.is_known for pval in primal_pvals)
  primals_out = [pval.const for pval in primal_pvals]
  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]
  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)
  return primals_out, f_vjp

def vjp(f, *primals_in):
  primals_in_flat, in_tree = tree_flatten(primals_in)
  f, out_tree = flatten_fun(f, in_tree)
  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)
  primals_out = tree_unflatten(out_tree(), primals_out_flat)

  def f_vjp(*cotangents_out):
    cotangents_out_flat, _ = tree_flatten(cotangents_out)
    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)
    return tree_unflatten(in_tree, cotangents_in_flat)

  return primals_out, f_vjp

class UndefPrimal(NamedTuple):
  aval: ShapedArray

register_pytree_node(UndefPrimal,
                     lambda u: (u.aval, ()),
                     lambda aval, _: UndefPrimal(aval))

我们使用 UndefPrimal 实例来指示我们希望对哪些参数进行转置。这是因为一般而言,关于封闭值,我们希望将类型为 a -> b -o c 的函数转置为类型为 a -> c -o b 的函数。更一般来说,函数的线性输入可能散布在参数列表中。因此,我们使用 UndefPrimal 来指示线性位置。我们将 UndefPrimal 注册为一个 pytree 节点,因为 pytree 机制提供了一种方便的方法来从参数列表中修剪掉这些占位符。

接下来,我们可以编写 eval_jaxpr_transposed,以及所有可以在至少一个参数上线性的原语的转置规则:

# 注意:JAX 中类似的函数称为 'backward_pass'
def eval_jaxpr_transposed(jaxpr: Jaxpr, args: list[Any], cotangents: list[Any]
                          ) -> list[Any]:
  primal_env: dict[Var, Any] = {}
  ct_env: dict[Var, Any] = {}

  def read_primal(x: Atom) -> Any:
    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val

  def write_primal(v: Var, val: Any) -> None:
    if type(val) is not UndefPrimal:
      primal_env[v] = val

  def read_cotangent(v: Var) -> Any:
    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))

  def write_cotangent(x: Atom, val: Any):
    if type(x) is Var and val is not None:
      ct_env[x] = add(ct_env[x], val) if x in ct_env else val

  map(write_primal, jaxpr.in_binders, args)
  map(write_cotangent, jaxpr.outs, cotangents)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.inputs)
    cts_in = map(read_cotangent, eqn.out_binders)
    rule = transpose_rules[eqn.primitive]
    cts_out = rule(cts_in, *primals_in, **eqn.params)
    map(write_cotangent, eqn.inputs, cts_out)

  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)
          if type(x) is UndefPrimal]

transpose_rules = {}
def mul_transpose_rule(cts, x, y):
  z_bar, = cts
  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)
  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]
transpose_rules[mul_p] = mul_transpose_rule

def neg_transpose_rule(cts, x):
  ybar, = cts
  assert type(x) is UndefPrimal
  return [neg(ybar)]
transpose_rules[neg_p] = neg_transpose_rule

def add_transpose_rule(cts, x, y):
  z_bar, = cts
  return [z_bar, z_bar]
transpose_rules[add_p] = add_transpose_rule

def reduce_sum_transpose_rule(cts, x, *, axis):
  y_bar, = cts
  return [broadcast(y_bar, x.aval.shape, axis)]
transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule

def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):
  del num_consts  # 未使用
  undef_primals = [type(x) is UndefPrimal for x in invals]
  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))
  residuals, _ = partition_list(undef_primals, invals)
  outs = bind(xla_call_p, *new_consts, *residuals, *cts,
              jaxpr=transposed_jaxpr, num_consts=len(new_consts))
  outs = iter(outs)
  return [next(outs) if undef else None for undef in undef_primals]
transpose_rules[xla_call_p] = xla_call_transpose_rule

@lru_cache
def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: tuple[bool, ...]
                    ) -> tuple[Jaxpr, list[Any]]:
  avals_in, avals_out = typecheck_jaxpr(jaxpr)
  traceable = partial(eval_jaxpr_transposed, jaxpr)
  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]
  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))
  typecheck_jaxpr(trans_jaxpr)
  return trans_jaxpr, consts

现在我们可以进行线性化和转置,最终可以编写 grad

def grad(f):
  def gradfun(x, *xs):
    y, f_vjp = vjp(f, x, *xs)
    if np.shape(y) != (): raise TypeError
    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))
    return x_bar
  return gradfun
y, f_vjp = vjp(sin, 3.)
print(f_vjp(1.), cos(3.))
(np.float64(-0.9899924966004454),) -0.9899924966004454
def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(grad(f)(3.))
2.979984993200891
@jit
def f(x):
  y = x * 2.
  z = g(y)
  return z

@jit
def g(x):
  return cos(x) * 2.

print(grad(f)(3.))
1.1176619927957034

这里有一个组合性压力测试:

# 来自 core_test.py 的嵌套调用功能测试 2
def foo(x):
  @jit
  def bar(y):
    def baz(w):
      q = jit(lambda x: y)(x)
      q = q + jit(lambda: y)()
      q = q + jit(lambda y: w + y)(y)
      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q
      return q
    p, t = jvp(baz, (x + 1.0,), (y,))
    return t + (x * p)
  return bar(x)

def assert_allclose(*vals):
  for v1, v2 in zip(vals[:-1], vals[1:]):
    np.testing.assert_allclose(v1, v2)

ans1 = f(3.)
ans2 = jit(f)(3.)
ans3, _ = jvp(f, (3.,), (5.,))
ans4, _ = jvp(jit(f), (3.,), (5.,))
assert_allclose(ans1, ans2, ans3, ans4)

deriv1 = grad(f)(3.)
deriv2 = grad(jit(f))(3.)
deriv3 = jit(grad(jit(f)))(3.)
_, deriv4 = jvp(f, (3.,), (1.,))
_, deriv5 = jvp(jit(f), (3.,), (1.,))
assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)

hess1 = grad(grad(f))(3.)
hess2 = grad(grad(jit(f)))(3.)
hess3 = grad(jit(grad(f)))(3.)
hess4 = jit(grad(grad(f)))(3.)
_, hess5 = jvp(grad(f), (3.,), (1.,))
_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))
_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))
assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)

第五部分:控制流原语 cond#

接下来,我们将添加用于分阶段控制流的高阶原语。这些原语类似于第三部分中的 jit,另一种高阶原语,但不同的是,它们是由多个可调用对象而不是只有一个参数化的。

添加 cond#

我们引入一个 cond 原语来表示在 jaxpr 内部条件性地应用一个函数或另一个函数。我们将 cond 的类型写为 Bool -> (a -> b) -> (a -> b) -> a -> b。换句话说,cond 接受一个布尔值作为谓词和两个相同类型的函数。根据谓词的值,它将一个函数或另一个函数应用于其最终参数。

在 Python 中,我们将其表示为一个函数,该函数本身接受两个函数作为参数。与 jit 一样,第一步是对其可调用参数调用 make_jaxpr 将它们转换为 jaxprs:

def cond(pred, true_fn, false_fn, *operands):
  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]
  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)
  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)
  if out_tree != out_tree_: raise TypeError
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  outs = bind_cond(pred, *true_consts, *false_consts, *operands,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return tree_unflatten(out_tree, outs)
cond_p = Primitive('cond')

def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                       ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]
  consts1, rest1 = split_list(jaxpr1.in_binders, n1)
  consts2, rest2 = split_list(jaxpr2.in_binders, n2)
  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)
  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)
  return new_jaxpr1, new_jaxpr2

def bind_cond(pred, *args, true_jaxpr, false_jaxpr):
  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)
  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)

我们要求 true_jaxprfalse_jaxpr 具有相同的类型,但因为它们可能覆盖不同的常量(并且因为 jaxpr 只能表示封闭的项,即不能有自由变量,而是进行闭包转换),我们需要使用帮助器 _join_jaxpr_consts 来统一两个 jaxpr 的输入绑定列表。 (为了更经济,我们可以尝试识别形状相同的常量对,但我们只连接常量列表。)

接下来,我们可以着手为 cond 添加解释器规则。它的评估规则很简单:

def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):
  if pred:
    return eval_jaxpr(true_jaxpr, operands)
  else:
    return eval_jaxpr(false_jaxpr, operands)
impl_rules[cond_p] = cond_impl
out = cond(True, lambda: 3, lambda: 4)
print(out)
3

对于它的 JVP 和 vmap 规则,我们只需要调用为 jit 创建的相同的 jvp_jaxprvmap_jaxpr 工具,然后再执行一次 _join_jaxpr_consts

def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):
  pred, *primals = primals
  _   , *tangents = tangents
  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)
  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  primals_out, tangents_out = split_half(outs)
  return primals_out, tangents_out
jvp_rules[cond_p] = cond_jvp_rule
out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))
print(out_tan)
2.0
def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):
  pred    , *vals_in = vals_in
  pred_dim, *dims_in = dims_in
  if pred_dim is not not_mapped: raise NotImplementedError  # 待办事项
  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))
  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)
  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  return outs, [0] * len(outs)
vmap_rules[cond_p] = cond_vmap_rule
xs = np.array([1., 2., 3])
out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)
print(out)
[2. 3. 4.]

注意,我们目前不支持谓词值本身被批处理的情况。在主线 JAX 中,我们通过将条件转换为一个 select 原语 来处理这种情况。只要 true_funfalse_fun 不涉及任何有副作用的原语,这种转换在语义上是正确的。

这里没有表示的另一个情况,但在主线 JAX 中存在的是,对两个相同类型的 jaxpr 应用转换可能会导致不同类型的 jaxpr。例如,将主线 JAX 版本的 vmap_jaxpr 应用于恒等函数 jaxpr

{ lambda a:float32[] .
  let
  in ( a ) }

将会生成一个批处理输出的 jaxpr,类型为 [float32[10]] -> [float32[10]],如果批处理大小为 10,而将其应用于零函数 jaxpr

{ lambda a:float32[] .
  let
  in ( 0. ) }

将会生成一个未批处理输出的 jaxpr,类型为 [float32[10]] -> [float32[]]。这是一种优化,旨在避免不必要的值批处理。但这意味着在 cond 中,我们需要额外一步将两个转换后的 jaxpr 结合,以保持输出类型一致。这里不需要这一步,因为我们选择了 vmap_jaxpr 始终在前导轴上批处理所有输出。

接下来我们可以转向抽象评估和XLA降级规则:

def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):
  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError
  jaxpr_type = typecheck_jaxpr(true_jaxpr)
  if jaxpr_type != typecheck_jaxpr(false_jaxpr):
    raise TypeError
  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):
    raise TypeError
  return jaxpr_type.out_types
abstract_eval_rules[cond_p] = cond_abstract_eval

def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):
  del in_avals  # 未使用的
  pred, *in_vals = in_vals
  flat_vals, in_tree = tree_flatten(in_vals)
  operand = xops.Tuple(c, flat_vals)
  operand_shape = c.get_shape(operand)

  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:
    c = xc.XlaBuilder(name)
    operand = xops.Parameter(c, 0, operand_shape)
    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))
    outs = jaxpr_subcomp(c, jaxpr, operands)
    return c.build(xops.Tuple(c, outs))

  true_comp = make_comp('true_fn', true_jaxpr)
  false_comp = make_comp('false_fn', false_jaxpr)

  int_etype = xc.dtype_to_etype(np.dtype('int32'))
  out = xops.Conditional(xops.ConvertElementType(pred, int_etype),
                         [false_comp, true_comp], [operand] * 2)
  return destructure_tuple(c, out)
xla_translations[cond_p] = cond_translation
out = jit(lambda: cond(False, lambda: 1, lambda: 2))()
print(out)
2

最后,为了支持反向模式自动微分,我们需要部分求值和转置规则。对于部分求值,我们需要引入另一个 jaxpr 处理工具 _join_jaxpr_res,来处理对 true_funfalse_fun 应用部分求值通常会导致不同剩余的情况。我们使用 _join_jaxpr_res 来使变换后的 jaxpr 的输出类型一致(而 _join_jaxpr_consts 处理的是输入类型)。

def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):
  pred_tracer, *tracers = tracers
  assert pred_tracer.pval.is_known
  pred = pred_tracer.pval.const
  in_uks = [not t.pval.is_known for t in tracers]

  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs

  known_tracers, unknown_tracers = partition_list(in_uks, tracers)
  known_vals = [t.pval.const for t in known_tracers]
  outs1_res = bind_cond(pred, *known_vals,
                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)
  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)
  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))
  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]
  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)
           for v in t_jaxpr2.outs]
  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],
                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))
  for t in outs2: t.recipe = eqn
  return merge_lists(out_uks, outs1, outs2)
partial_eval_rules[cond_p] = cond_partial_eval

def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: list[bool]
                       ) -> tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, list[bool], int]:
  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)
  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)
  out_uks = map(op.or_, t_out_uks, f_out_uks)

  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)
  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)

  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)
  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)
  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)
  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)
  num_res = t_nres + f_nres

  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res

def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int
                    ) -> tuple[Jaxpr, Jaxpr]:
  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)
  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)
  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)
  assert out_types1 == out_types2
  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)
  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)
  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]
  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]
  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)
  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)
  return new_jaxpr1, new_jaxpr2
_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)
out = f_lin(3.14)
print(out)
3.14
def cond_peval_eqn(unks_in: list[bool], eqn: JaxprEqn,
                   ) -> tuple[JaxprEqn, JaxprEqn, list[bool], list[Atom]]:
  pred_unk, *unks_in = unks_in
  assert not pred_unk
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)
  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs
  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])
  outs1, outs2 = partition_list(unks_out, eqn.out_binders)
  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)
  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],
                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),
                  outs1 + residuals)
  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],
                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),
                  outs2)
  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals
  return eqn1, eqn2, unks_out, res
partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn
_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)
out = f_lin(3.14)
print(out)
3.14

转置是对 transpose_jaxpr 的一种相当直接的应用:

def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):
  undef_primals = tuple(type(x) is UndefPrimal for x in invals)
  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)
  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)
  true_jaxpr, false_jaxpr = _join_jaxpr_consts(
      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))
  res = [x for x in invals if type(x) is not UndefPrimal]
  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,
                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)
  outs = iter(outs)
  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]
transpose_rules[cond_p] = cond_transpose_rule
out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)
print(out)
2.0
Hide code cell source
def pprint_cond(names: defaultdict[Var, str], eqn: JaxprEqn) -> PPrint:
  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']
  new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}
  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))
  rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>
         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)
                     for x in eqn.inputs)))
  return vcat([lhs >> pp(' = ') >> rhs,
               pp_jaxpr(true_jaxpr).indent(2),
               pp_jaxpr(false_jaxpr).indent(2)])
pp_rules[cond_p] = pprint_cond