How JAX primitives work#


necula@google.com, October 2019.

JAX implements transformations of certain Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes, such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on JAX abstract values.
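
As a small illustrative sketch (not part of the original text), jax.eval_shape makes this abstract evaluation visible: it runs a function using only shapes and dtypes, never touching concrete data:

import jax
import jax.numpy as jnp

# Abstractly evaluate a matmul: only shapes/dtypes flow through the trace.
abstract_out = jax.eval_shape(lambda x, y: jnp.matmul(x, y),
                              jax.ShapeDtypeStruct((2, 3), jnp.float32),
                              jax.ShapeDtypeStruct((3, 5), jnp.float32))
print(abstract_out)  # ShapeDtypeStruct(shape=(2, 5), dtype=float32)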

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).

There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.

The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.

The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.

Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on 3 identically-shaped tensors of floating point values and performs the operations pointwise.

Using existing primitives#

The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax module:

from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """使用jax.lax原语实现乘加操作。"""
  return lax.add(lax.mul(x, y), z)


def square_add_lax(a, b):
  """使用新定义的乘加运算的平方加函数。"""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax =  14.0
grad(square_add_lax) =  4.0

In order to understand how JAX is using the primitives internally, we add some helpers for tracing function calls.

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """在当前缩进处打印一条消息。"""
    if msg is not None:
        print("  " * _indentation + msg)

def _trace_indent(msg=None):
    """打印一条消息,然后缩进其余部分。"""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """取消缩进后打印一条消息。"""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """用于跟踪函数参数和结果的装饰器。"""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """更简洁地打印某些值"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """用于检查NotImplementedError的上下文管理器。"""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False

Instead of using the jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0

Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
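
Conversely, a function that needs the concrete value of traced data is not JAX-traceable. Here is a hypothetical counter-example (not from the original text): the Python `if` below must convert its traced argument to a concrete boolean, so the function works when evaluated normally (and under grad, which traces with concrete values), but fails under jit, which traces with abstract values:

def branchy_add(a, b):
  if a > 0:  # Python control flow needs a concrete boolean
    return a + b
  return b

print(branchy_add(2., 10.))     # works: concrete inputs
# api.jit(branchy_add)(2., 10.) # would raise a tracer-conversion error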

Defining new JAX primitives#

The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """使用JAX原语的可追踪方式。

请注意,被追踪的参数必须作为位置参数传递给`bind`。
  """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """使用新JAX原语实现的平方加法函数。"""
  return multiply_add_prim(a, a, b)

If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.

with expectNotImplementedError():
  square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception:
Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/158865747.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 48, in func_wrapper
    res = func(*args)
          ^^^^^^^^^^^
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 15, in square_add_prim
    return multiply_add_prim(a, a, b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Evaluation rule for 'multiply_add' not implemented

Primal evaluation rules#

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.
  Args:
    x, y, z: the concrete arguments of the primitive. Will only be called with
      concrete values.
  Returns:
    the concrete result of the primitive.
  """
  # Note that we can use the original numpy, which is not JAX-traceable.
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX.
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0

JIT#

If we now try to use jit we get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

Found expected exception:
Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1876526605.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented

Abstract evaluation rules#

In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:

  • Gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.

  • Computes the shape and type of all vectors and operations used in the computation.

For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
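
As a small sketch (not in the original), such abstract values can be constructed directly; the abstract evaluation rule defined below returns exactly this kind of ShapedArray:

from jax import core
import numpy as np

# An abstract value for a float32 vector of 3 elements; no data is stored.
aval = core.ShapedArray((3,), np.float32)
print(aval)  # ShapedArray(float32[3])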

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments.
  Args:
    xs, ys, zs: abstractions of the arguments.
  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX.
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>

If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about a missing actual XLA compilation rule:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

Found expected exception:
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1876526605.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

XLA compilation rules#

JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """XLA 编译原语。

给定每个参数的 mlir.ir.Value,返回函数结果的 mlir.ir.Values。

不需要是 JAX 可追踪的函数。
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX.
# For GPU, see [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>

Now the JIT compilation succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering.

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x1180d1250>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x118100330>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x1181003b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118100370>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x1181000d0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10fd04080>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py":1:28) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py":1:7) at callsite("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at "_pseudo_sync_runner"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <lambda> at 0x10fc23bb0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py", line 1>, 30): loc("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py":1:28)), (<code object <module> at 0x10fc3e120, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py", line 1>, 54): 
loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py":1:7)), (<code object run_code at 0x105ef8bc0, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3541>, 202): loc("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)), (<code object run_ast_nodes at 0x129139000, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3418>, 1258): loc("InteractiveShell.run_ast_nodes"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19)), (<code object run_cell_async at 0x12913c600, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3183>, 1858): loc("InteractiveShell.run_cell_async"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29)), (<code object _pseudo_sync_runner at 0x10763f630, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 119>, 30): loc("_pseudo_sync_runner"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1509938778.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': True, 
'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x118101ed0>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x10fbaffb0>]

Below is another use of jit, where we compile only with respect to the first argument. Notice how the second argument to square_add_prim is concrete, which leads to the third argument to multiply_add_abstract_eval being ConcreteArray. We see that multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray.

assert api.jit(lambda x, y: square_add_prim(x, y),
               static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x1180d11f0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x10fbfe9f0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x10fbfd170>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x10fbfe130>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x11809b310>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f277a60>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py":1:28) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py":1:7) at callsite("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20) at callsite("InteractiveShell.run_ast_nodes"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19) at callsite("InteractiveShell.run_cell_async"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29) at "_pseudo_sync_runner"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <lambda> at 0x10fc23c90, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py", line 1>, 30): loc("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py":1:28)), (<code object <module> at 0x10fc3e890, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py", line 1>, 58): 
loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py":1:7)), (<code object run_code at 0x105ef8bc0, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3541>, 202): loc("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)), (<code object run_ast_nodes at 0x129139000, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3418>, 1258): loc("InteractiveShell.run_ast_nodes"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3517:19)), (<code object run_cell_async at 0x12913c600, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3183>, 1858): loc("InteractiveShell.run_cell_async"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3334:29)), (<code object _pseudo_sync_runner at 0x10763f630, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 119>, 30): loc("_pseudo_sync_runner"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py":128:8))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3805016149.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': True, 
'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x1180d5d10>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x10fbe5230>]

Forward differentiation#

JAX implements forward differentiation in the form of a Jacobian-vector product (JVP): given a function f, primals x, and tangents v, the JVP computes the pair (f(x), ∂f(x)·v). (See the JAX autodiff cookbook.)

If we attempt now to compute the jvp function we get an error, because we have not yet told JAX how to differentiate the multiply_add primitive.

# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception:
Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1807567232.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py", line 1922, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py", line 1951, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax.interpreters import ad


@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
     a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now, we have a JAX-traceable computation of the output.
  # Normally, we can use the multiply_add (`ma`) primitive itself to compute the
  # primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same
  # "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here, we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX.
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>

To explain:

  • Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.

  • Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call multiply_add_abstract_eval.

  • I think it would be useful to show the jaxpr here (a sketch follows below).
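
Following up on the last note, here is a sketch (not in the original) that prints the jaxpr of the jvp computation; the trace records three multiply_add equations, one for the primal and two for the tangent:

import jax

# Trace the jvp of square_add_prim with abstract values and show the jaxpr.
print(jax.make_jaxpr(lambda args, tangents: api.jvp(square_add_prim, args, tangents))(
    (2., 10.), (1., 1.)))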

JIT of forward differentiation#

We can apply JIT to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents:
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x11812be30>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x118143230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x1181432b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118143270>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x1181430d0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f274490>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <lambda> at 0x10fb7a630, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 64): loc("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19)), (<code 
object <module> at 0x10fc3e890, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 54): loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x118144890>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x118144ab0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x11812be30>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x118143230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x1181432b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118143270>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x1181430d0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f274490>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7))))))))))), <jaxlib.xla_extension.Traceback object at 0x10f273d50>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7)))))))))))}, 
location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <lambda> at 0x10fb7a630, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 64): loc("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19)), (<code object <module> at 0x10fc3e890, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 54): loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 232): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, 
'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x118144e90>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x118144ef0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x11812be30>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x118143230>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x1181432b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118143270>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x1181430d0>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f274490>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7))))))))))), <jaxlib.xla_extension.Traceback object at 0x10f273d50>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7))))))))))), 
<jaxlib.xla_extension.Traceback object at 0x10f273f70>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19) at "<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":27:15)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <lambda> at 0x10fb7a630, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 64): loc("<lambda>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":2:19)), (<code object <module> at 0x10fc3e890, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py", line 1>, 54): loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py":1:7)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 232): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 246): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19))}, 
canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3503505880.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x118145250>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x10b9ad330>]

Notice that first we evaluate multiply_add_value_and_jvp abstractly, which in turn evaluates abstractly both the primal and the tangent computation (a total of 3 invocations of the ma primitive). Then we compile the 3 occurrences of the primitive.
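
To see those three occurrences in one place, we can print the jaxpr of the forward-differentiated function (a sketch using the public jax.make_jaxpr API; the lambda below is illustrative scaffolding, not part of the tutorial's primitives):

import jax

# Sketch: trace the JVP of square_add_prim and print the resulting jaxpr.
# We expect three multiply_add equations: one primal use and two tangent uses.
print(jax.make_jaxpr(
    lambda x, y: api.jvp(square_add_prim, (x, y), (1., 0.)))(2., 10.))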

Reverse differentiation#

If we now attempt to use reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.

When computing reverse differentiation, JAX first performs an abstract evaluation of the forward-differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent corresponding to the third argument of ma. This reflects the fact that we do not differentiate w.r.t. the second argument of square_add_prim, which flows into the third argument of multiply_add_prim.

Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the third argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.
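
For reference, here is a minimal sketch of that zero-instantiation, assuming it mirrors the make_zero helper used inside multiply_add_value_and_jvp (the helper itself was defined earlier in this document):

from jax.interpreters import ad
import jax.numpy as jnp

def make_zero_sketch(tan):
  # Materialize a symbolic ad.Zero tangent as a concrete array of zeros,
  # so that it can flow into multiply_add_prim like any other value.
  return jnp.zeros(tan.aval.shape, tan.aval.dtype) if type(tan) is ad.Zero else tan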

# This is reverse differentiation w.r.t. the first argument of square_add_prim.
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
Found expected exception:
Traceback (most recent call last):
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py", line 283, in get_primitive_transpose
    return primitive_transposes[p]
           ~~~~~~~~~~~~~~~~~~~~^^^
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2630988720.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py", line 633, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The above error is because JAX is missing a piece that it needs in order to use the forward-differentiation code to compute reverse differentiation.

Transposition#

As explained above, when computing reverse differentiation, JAX obtains a trace of primitives that compute the tangent using forward differentiation. JAX then interprets this trace abstractly backwards, applying a transposition rule to each primitive.

To understand what is going on, consider for now the simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point (2., 4.). JAX will produce the following JVP tangent calculation of ft from the tangents of the inputs, xt and yt:

   a = xt * 4.
   b = 2. * yt
   c = a + b
   ft = c + yt
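
This tangent program is short enough to run by hand as plain Python (a sketch; the example tangents xt = 1., yt = 0. are chosen to recover the partial derivative w.r.t. x):

xt, yt = 1., 0.   # example input tangents
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt       # ft == 4.0, the directional derivative along (1., 0.)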

By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is a constant.

JAX will produce the reverse-differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation:

  # Initialize cotangents of inputs and intermediate variables
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output
  fct = 1.
  # Process "ft = c + yt"
  cct += fct
  yct += fct
  # Process "c = a + b"
  act += cct
  bct += cct
  # Process "b = 2. * yt"
  yct += 2. * bct
  # Process "a = xt * 4."
  xct += act * 4.

One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f.
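
As a quick sanity check, the same partial derivatives can be obtained with the public jax.grad API (f below is scaffolding for this toy example, not one of the tutorial's primitives):

import jax

def f(x, y):
  return x * y + y

# Expect (4.0, 3.0), matching xct and yct above.
print(jax.grad(f, argnums=(0, 1))(2., 4.))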

JAX knows, for each primitive that may appear in a JVP calculation, how to transpose it. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is:

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)

Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value, and for the other arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value None for the constant arguments.

In particular,

 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None)
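
The two mult_transpose cases can be spelled out as plain Python (a sketch; the function and its convention of marking the linear argument with None are illustrative, not JAX's actual API):

def mult_transpose(out_ct, x, y):
  # Exactly one of x, y is the linear argument, marked here with None.
  if x is None:
    return (out_ct * y, None)   # the cotangent flows to the linear x
  else:
    return (None, x * out_ct)   # the cotangent flows to the linear y

assert mult_transpose(1.5, None, 4.0) == (6.0, None)
assert mult_transpose(1.5, 2.0, None) == (None, 3.0)
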
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

  This method is only used when computing the backward gradient following
  value_and_jvp, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. We need transposition for multiply_add_prim,
  because we have used multiply_add_prim in the computation of the output_tangent in
  multiply_add_value_and_jvp.

  In our case, multiply_add is not a linear primitive. However, it is used linearly
  w.r.t. tangents in multiply_add_value_and_jvp:
       output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

  Always one of the first two multiplicative arguments is a constant.

  Args:
      ct: the cotangent of the output of the primitive.
      x, y, z: values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.
  Returns:
      a tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x".
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y".
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res


ad.primitive_transposes[multiply_add_p] = multiply_add_transpose

Now we can complete the run of grad:

assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)

Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. The first call to the transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.
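
We can replay those two transpose calls by hand (a sketch that mirrors the trace above and reuses multiply_add_lax from the beginning of this document; the ct_* names are hypothetical):

ct_out = 1.0                             # cotangent of the output
x = y = 2.0                              # the differentiation point
ct_xt = multiply_add_lax(ct_out, y, 0.)  # outer transpose call -> 2.0
ct_yt = multiply_add_lax(x, ct_out, 0.)  # inner transpose call -> 2.0
# The additive argument passes its cotangent through unchanged, and both
# tangents feed the same input, so the gradient is their sum.
assert ct_xt + ct_yt == 4.               # == grad(square_add_prim)(2., 10.)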

JIT of reverse differentiation#

Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values, whereas without JIT we used ConcreteArray.

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x118168b30>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x1181767f0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x118176930>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118176830>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x118175750>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f27bdb0>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py":1:7) at "InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 246): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <module> at 0x10fa5fcc0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py", line 1>, 90): 
loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py":1:7)), (<code object run_code at 0x105ef8bc0, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3541>, 202): loc("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x118177d10>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(%0 = 
"stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x10e0335f0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x118168b30>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x1181767f0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x118176930>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x118176830>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x118175750>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10f27bdb0>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py":1:7) at "InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20))))))))))), <jaxlib.xla_extension.Traceback object at 0x10f274490>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py":1:7) at 
"InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 246): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:19)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <module> at 0x10fa5fcc0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py", line 1>, 90): loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py":1:7)), (<code object run_code at 0x105ef8bc0, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3541>, 202): loc("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)), (<code object multiply_add_value_and_jvp at 0x108d53130, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py", line 4>, 232): loc("multiply_add_value_and_jvp"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py":41:55))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, 
'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3184771669.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/ad.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1158933884.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x11817c090>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x118177d70>]

Batching#

The batching transformation turns a point-wise computation into a computation on vectors. If we try it now, we get a NotImplementedError:

# The arguments are two vectors instead of two scalars.
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                               np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

Found expected exception:
Traceback (most recent call last):
  File "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/1781857283.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py", line 1227, in vmap_f
    out_flat = batching.batch(
               ^^^^^^^^^^^^^^^
NotImplementedError: Batching rule for 'multiply_add' not implemented

We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on inputs of any dimension, so the batched version can use the same multiply_add_prim implementation.

from jax.interpreters import batching


@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as both the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.

  Args:
    vector_arg_values: a tuple of two arguments, each being a tensor of matching
      shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]


batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
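
The rule above assumes that all three arguments are batched along the same axis. A slightly more general rule could first align the batch axes (a sketch; multiply_add_batch_general is a hypothetical extension that assumes every argument is batched):

def multiply_add_batch_general(vector_arg_values, batch_axes):
  # Move each argument's batch axis to position 0, then apply the
  # pointwise primitive; the result is batched along axis 0.
  args = [jnp.moveaxis(v, ax, 0)
          for v, ax in zip(vector_arg_values, batch_axes)]
  return multiply_add_prim(*args), 0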

JIT of batching#

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x118169190>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x10fc1d530>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x11817cd70>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x11817c1b0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x10f99e200>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None, mesh_shape=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x11817cc50>, all_default_mem_kind=True, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x10dffbf80>: loc(callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_batch"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py":25:8) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9) at callsite("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12) at callsite("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py":1:19) at "InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20)))))))))))}, location_cache={(<code object multiply_add_prim at 0x10fb7a830, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 4>, 54): loc("multiply_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":10:9)), (<code object func_wrapper at 0x1088b4db0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py", line 45>, 98): loc("trace.<locals>.trace_func.<locals>.func_wrapper"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py":48:12)), (<code object multiply_add_batch at 0x10f989d10, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py", line 4>, 126): loc("multiply_add_batch"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py":25:8)), (<code object square_add_prim at 0x10fbee5d0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py", line 12>, 32): loc("square_add_prim"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py":15:9)), (<code object <module> at 0x10f97e3a0, file "/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py", line 1>, 204): 
loc("<module>"("/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py":1:19)), (<code object run_code at 0x105ef8bc0, file "/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3541>, 202): loc("InteractiveShell.run_code"("/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py":3577:20))}, canonical_name_cache={'/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py', '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py': '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py', '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/source_info_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/core.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2492108615.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/2243015332.py': True, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3297882078.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/interpreters/batching.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/linear_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/api.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/profiler.py': False, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/pjit.py': False, '/var/folders/xc/cwj7_pwj6lb0lkpyjtcbm7y80000gn/T/ipykernel_22763/3465019417.py': True, '/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False, export_ignore_forward_compatibility=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x11817dc50>, tokens_out=None, axis_size_env=None, dim_var_values=[], jaxpr_eqn_ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={}, platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block 
argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x1180f7770>]