使用 pytrees#

JAX 内置支持看起来像数组字典(dicts)的对象,或者字典列表的列表,或者其他嵌套结构——在 JAX 中,这些被称为 pytrees。本节将解释如何使用它们,提供有用的代码示例,并指出常见的“陷阱”和模式。

什么是 pytree?#

pytree 是一种类似容器的结构,由类似容器的 Python 对象构建——“叶子” pytree 和/或其他 pytree。pytree 可以包含列表、元组和字典。叶子是任何不是 pytree 的东西,例如数组,但单个叶子也是一个 pytree。

在机器学习 (ML) 的背景下,一个 pytree 可以包含:

  • 模型参数

  • 数据集条目

  • 强化学习代理观察

在处理数据集时,你经常会遇到 pytrees(例如字典的列表的列表)。

下面是一个简单的 pytree 示例。在 JAX 中,你可以使用 jax.tree.leaves() 从树中提取扁平化的叶子,如下所示:

import jax
import jax.numpy as jnp

example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Print how many leaves the pytrees have.
for pytree in example_trees:
  # This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
  leaves = jax.tree.leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
[1, 'a', <object object at 0x11016b9d0>]      has 3 leaves: [1, 'a', <object object at 0x11016b9d0>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]

任何由类似容器的Python对象构建的树状结构都可以在JAX中被视为pytree。如果类在pytree注册表中,则被认为是类似容器的,默认情况下包括列表、元组和字典。任何类型在pytree容器注册表中的对象将被视为树中的叶节点。

pytree 注册表可以通过注册类来扩展,以包含用户定义的容器类,这些类指定了如何展平树的函数;请参见下面的 自定义 pytree 节点

常见的 pytree 函数#

JAX 提供了许多用于操作 pytrees 的工具。这些工具可以在 jax.tree_util 子包中找到;为了方便起见,其中许多工具在 jax.tree 模块中有别名。

常用功能:jax.tree.map#

最常用的 pytree 函数是 jax.tree.map()。它的工作原理类似于 Python 的原生 map,但透明地操作整个 pytree。

这是一个例子:

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree.map(lambda x: x*2, list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

jax.tree.map() 还允许将 N-ary 函数映射到多个参数上。例如:

another_list_of_lists = list_of_lists
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

在使用 jax.tree.map() 时,如果使用多个参数,输入的结构必须完全匹配。也就是说,列表必须具有相同数量的元素,字典必须具有相同的键等。

使用 jax.tree.map 处理机器学习模型参数的示例#

这个例子展示了在训练一个简单的 多层感知器 (MLP) 时,pytree 操作如何有用。

开始定义初始模型参数:

import numpy as np

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

使用 jax.tree.map() 来检查初始参数的形状:

jax.tree.map(lambda x: x.shape, params)
[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

接下来,定义用于训练MLP模型的函数:

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

自定义 pytree 节点#

本节解释了在 JAX 中如何通过使用 jax.tree_util.register_pytree_node()jax.tree.map() 来扩展将被视为 pytrees(pytree 节点)中 内部节点 的 Python 类型集。

为什么你需要这个?在前面的例子中,pytrees 被展示为列表、元组和字典,其他所有内容都被视为 pytree 叶子。这是因为如果你定义了自己的容器类,它将被视为 pytree 叶子,除非你用 JAX 注册 它。即使你的容器类内部有树,也是如此。例如:

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

jax.tree.leaves([
    Special(0, 1),
    Special(2, 4),
])
[<__main__.Special at 0x116a3fd10>, <__main__.Special at 0x115f85950>]

因此,如果你尝试使用 jax.tree.map() 并期望叶子是容器内的元素,你将会遇到错误:

jax.tree.map(lambda x: x + 1,
  [
    Special(0, 1),
    Special(2, 4)
  ])
TypeError: unsupported operand type(s) for +: 'Special' and 'int'

作为一种解决方案,JAX 允许通过全局类型注册表扩展被视为内部 pytree 节点的类型集合。此外,已注册类型的值会被递归遍历。

首先,使用 jax.tree_util.register_pytree_node() 注册一个新类型:

from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)

现在你可以遍历这个特殊的容器结构:

jax.tree.map(lambda x: x + 1,
  [
   RegisteredSpecial(0, 1),
   RegisteredSpecial(2, 4),
  ])
[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

现代 Python 配备了有用的工具,使定义容器更容易。有些可以直接与 JAX 配合使用,但其他一些则需要更多注意。

例如,一个Python NamedTuple 子类不需要注册就可以被视为 pytree 节点类型:

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 5, 6)
])
['Alice', 1, 2, 3, 'Bob', 4, 5, 6]

注意,name 字段现在显示为一个叶子节点,因为所有元组元素都是子节点。这就是当你不需要硬性注册类时发生的情况。

Pytrees 和 JAX 变换#

许多 JAX 函数,如 jax.lax.scan(),操作于数组的 pytrees 之上。此外,所有 JAX 函数变换都可以应用于接受输入并产生输出为数组的 pytrees 的函数。

一些 JAX 函数变换接受可选参数,这些参数指定如何处理某些输入或输出值(例如 jax.vmap()in_axesout_axes 参数)。这些参数也可以是 pytrees,并且它们的结构必须与相应参数的 pytree 结构相对应。特别是,为了能够“匹配”这些参数 pytrees 中的叶子与参数 pytrees 中的值,参数 pytrees 通常被限制为参数 pytrees 的树前缀。

例如,如果你将以下输入传递给 jax.vmap()(注意,函数的输入参数被视为一个元组):

vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))

然后你可以使用以下 in_axes pytree 来指定只有 k2 参数被映射(axis=0),而其余的参数不被映射(axis=None):

vmap(f, in_axes=(None, {"k1": None, "k2": 0}))

可选参数 pytree 结构必须与主输入 pytree 的结构相匹配。然而,可选参数可以可选地指定为“前缀” pytree,这意味着单个叶值可以应用于整个子 pytree。

例如,如果你有与上面相同的 jax.vmap() 输入,但希望仅映射字典参数,你可以使用:

vmap(f, in_axes=(None, 0))  # equivalent to (None, {"k1": 0, "k2": 0})

或者,如果你想将每个参数都映射,你可以写一个单一的叶值,该值应用于整个参数元组 pytree:

vmap(f, in_axes=0)  # equivalent to (0, {"k1": 0, "k2": 0})

这恰好是 jax.vmap() 的默认 in_axes 值。

同样的逻辑适用于其他引用特定输入或输出值的可选参数,例如 jax.vmap() 中的 out_axes

显式键路径#

在 pytree 中,每个叶子都有一个 键路径。叶子的键路径是一个 列表,列表的长度等于叶子在 pytree 中的深度。每个 是一个 可哈希对象,表示对应 pytree 节点类型的索引。键的类型取决于 pytree 节点类型;例如,字典 的键类型与 元组 的键类型不同。

对于内置的 pytree 节点类型,任何 pytree 节点实例的键集合是唯一的。对于由具有此属性的节点组成的 pytree,每个叶子的键路径是唯一的。

JAX 有以下 jax.tree_util.* 方法用于处理键路径:

例如,一个用例是打印与某个叶子值相关的调试信息:

import collections

ATuple = collections.namedtuple("ATuple", ('name'))

tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)

for key_path, value in flattened:
  print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
Value of tree[0]: 1
Value of tree[1]['k1']: 2
Value of tree[1]['k2'][0]: 3
Value of tree[1]['k2'][1]: 4
Value of tree[2].name: foo

为了表示键路径,JAX 为内置的 pytree 节点类型提供了几种默认的键类型,即:

  • SequenceKey(idx: int):用于列表和元组。

  • DictKey(key: Hashable):用于字典。

  • GetAttrKey(name: str):用于 namedtuple 和首选的自定义 pytree 节点(更多内容见下一节)

你可以自由定义你自定义节点的键类型。只要它们的 __str__() 方法也被重写为一个对读者友好的表达式,它们就能与 jax.tree_util.keystr() 一起工作。

for key_path, _ in flattened:
  print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
Key path of tree[0]: (SequenceKey(idx=0),)
Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))
Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))
Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))
Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))

常见的 pytree 陷阱#

本节涵盖了在使用 JAX pytrees 时遇到的一些最常见的问题(”陷阱”)。

将 pytree 节点误认为叶子#

一个常见的需要注意的问题是意外引入 树节点 而不是 叶子

a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another pytree with ones instead of zeros.
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),
 (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]

这里发生的情况是,数组的 shape 是一个元组,这是一个 pytree 节点,其元素作为叶子。因此,在映射中,不是在例如 (2, 3) 上调用 jnp.ones,而是在 23 上调用。

解决方案将取决于具体情况,但有两种广泛适用的选项:

  • 重写代码以避免使用中间的 jax.tree.map()

  • 将元组转换为 NumPy 数组(np.array)或 JAX NumPy 数组(jnp.array),这将使整个序列成为一个叶子节点。

jax.tree_utilNone 的处理#

jax.tree_util 函数将 None 视为 pytree 节点的缺失,而不是叶子节点:

jax.tree.leaves([None, None, None])
[]

要将 None 视为叶子,可以使用 is_leaf 参数:

jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
[None, None, None]

自定义 pytrees 及使用意外值进行初始化#

用户定义的 pytree 对象的另一个常见问题是,JAX 变换偶尔会用意外的值初始化它们,因此任何在初始化时进行的输入验证都可能失败。例如:

class MyTree:
  def __init__(self, a):
    self.a = jnp.asarray(a)

register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
    lambda _, args: MyTree(*args))

tree = MyTree(jnp.arange(5.0))

jax.vmap(lambda x: x)(tree)      # Error because object() is passed to `MyTree`.
TypeError: Value '<object object at 0x116ab40b0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
jax.jacobian(lambda x: x)(tree)  # Error because MyTree(...) is passed to `MyTree`.
/Users/cw/baidu/code/fin_tool/github/jax/venv/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4252: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
TypeError: Value '<object object at 0x116ab4500>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
  • 在第一个例子中,使用 jax.vmap(...)(tree),JAX 的内部使用 object() 值的数组来推断树的结构。

  • 在第二种情况 jax.jacobian(...)(tree) 中,将树映射到树的函数的雅可比矩阵被定义为树的树。

潜在解决方案 1:

  • 自定义 pytree 类的 __init____new__ 方法通常应避免进行任何数组转换或其他输入验证,或者预见并处理这些特殊情况。例如:

class MyTree:
  def __init__(self, a):
    if not (type(a) is object or a is None or isinstance(a, MyTree)):
      a = jnp.asarray(a)
    self.a = a

潜在解决方案 2:

  • 构建你的自定义 tree_unflatten 函数,使其避免调用 __init__。如果你选择这条路径,确保你的 tree_unflatten 函数与 __init__ 保持同步,如果代码更新的话。示例:

def tree_unflatten(aux_data, children):
  del aux_data  # Unused in this class.
  obj = object.__new__(MyTree)
  obj.a = a
  return obj

常见的 pytree 模式#

本节涵盖了JAX pytrees中的一些最常见模式。

使用 jax.tree.mapjax.tree.transpose 转置 pytrees#

要将一个 pytree 转置(将树列表转换为列表树),JAX 有两个函数:{func} jax.tree.map(更基础)和 jax.tree.transpose()(更灵活、复杂且冗长)。

选项 1: 使用 jax.tree.map()。以下是一个示例:

def tree_transpose(list_of_trees):
  """
  Converts a list of trees of identical structure into a single tree of lists.
  """
  return jax.tree.map(lambda *xs: list(xs), *list_of_trees)

# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
{'obs': [3, 4], 't': [1, 2]}

选项 2: 对于更复杂的转置,使用 jax.tree.transpose(),它更详细,但允许您指定内部和外部 pytree 的结构以获得更大的灵活性。例如:

jax.tree.transpose(
  outer_treedef = jax.tree.structure([0 for e in episode_steps]),
  inner_treedef = jax.tree.structure(episode_steps[0]),
  pytree_to_transpose = episode_steps
)
{'obs': [3, 4], 't': [1, 2]}