高效转置复制诱导的集体#
mattjj@, dougalm@
2023年8月
动机#
我们在自动转置包含某些集合的 shmap
时存在效率问题。问题出现在 psum
和 all_gather
中,特别是当集合的输出作为未映射的输出返回给调用者时。这并不是一个边缘情况:例如,当对基于 shmap
的批量数据并行神经网络损失函数应用 grad
时,使用 psum
计算总损失时就会出现这种情况。
我们已经知道这个问题有一段时间了。pmap
存在类似的问题,尽管通过将 grad
保持在 pmap
内部而不是外部来解决了这个问题。不完全的 avals-with-names 工作的主要目标之一是解决这种转置效率问题的一个版本。本文档借鉴了这些想法,同时扩展和修订了它们,以处理更多情况并使其更容易实现。事实上,这里提出的解决方案只影响 shmap
实现。系统的其余部分不需要更改(还不需要)。
本文档的主要目的是定义这个转置效率问题,并提出一个易于实施的解决方案。
本文档不涉及:
数组上的逻辑轴名称(这里唯一的轴名称与
shmap
和 OGpmap
中的一样);改变自动微分语义(所有数字和(非)错误保持不变,我们只是让事情更高效);
允许用户代码反映任何新信息,或者实际上影响用户代码。
问题:psum
或 all_gather
的高效转置取决于余切是否在设备间不变。#
考虑这个半现实的例子,旨在模拟一个复制的参数批量数据并行损失函数:
devices = jax.devices() # 8 devices
@partial(shmap, mesh=Mesh(devices, ('batch',)),
in_specs=(P(None, None), P('batch', None)),
out_specs=P())
def loss(params, batch):
inputs, targets = batch
predictions = predict(params, inputs)
local_loss = jnp.mean(jnp.sum(predictions - targets, -1))
global_loss = lax.pmean(local_loss, 'batch'))
return global_loss
注意 out_specs=P()
,这表示一个未映射的输出。如果您不熟悉未映射输出的概念,请参阅本文档底部的附录。
loss
示例中的大部分细节并不重要。对我们来说,重要的是我们在最后应用了 psum
(或者更确切地说 pmean = lambda x, name: psum(x, name) / psum(1, name)
)。所以一个简化的版本看起来像这样:
# Example 1: shmap involving psum and unmapped output with inefficient transpose
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
我们甚至通过省略 mesh
参数简化了符号表示。在接下来的示例中,它可以根据上下文推断出来。
转置是什么样子的?用 t
表示函数转置,我们可以通过应用下面的函数 ¿f1_transpose?
高效地计算任意 ybar
的 t(f1)(ybar)
:
# An efficient "transpose" of Example 1 (but don't transpose this again!)
¿f1_transpose? = shmap(t(g), in_specs=P(), out_specs=P('i'))
但这不是我们目前得到的转置 t(f1)。
相反,当前的转置配方大致是我们交换 in_specs
和 out_specs
,对未映射的输出进行一些除法重缩放,并转置主体。因为 psum
是它自己的转置(作为一个全减少和),我们最终产生这个转置:
# The transpose we currently get for Example 1 (which is fine to transpose again)
t(f1) = shmap(lambda ybar: t(g)(psum(ybar / 8, 'i')),
in_specs=P(), out_specs=P('i'))
这个转置得到了正确的数字,但它是浪费的。我们从转置的 in_specs=P()
中静态地知道 ybar
对于每个函数实例具有相同的值,即它的值对于沿名为 i
的网格轴的设备是不变的,然而我们却对它应用了一个 psum
!这仅仅是为了将每个设备上的值乘以 8 而使用了昂贵的通信。(这里的 8 指的是轴 i 的大小。除以 8 来自于原始函数的 out_specs=P()
;它和简单的 psum
基本上相互抵消。)
我们做错了什么?我们没有利用这样一个事实,即对应于 f1
未映射输出的余切 ybar
保证是设备不变的;相反,我们防御性地对它们进行 psum
求和,就好像它们不是设备不变的一样,因为 psum
的转置无法根据其拥有的局部信息确定。有时 psum
是必要的,例如在相对于其第一个参数转置 f2
时:
# Example 2: shmap involving psum and *mapped* output with efficient transpose
f2 = shmap(lambda x, y: psum(g(x), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# The transpose we currently get for Example 2 is efficient
t(f2, 0) = shmap(lambda y, zbar: t(g)(psum(zbar * y, 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
直观地说,如果我们的转置机制能够区分示例1和示例2,我们就可以通过尽可能避免psum和除法来做得更好。
低效的示例甚至可以更小。考虑转置这个被诅咒的恒等函数:
# Example 3: cursed identity
cursed_identity = shmap(lambda x: x, P(), P())
# Currently we get these inefficient transposes
t(cursed_identity) = shmap(lambda x: psum(x / 8, 'i'), P(), P())
t(t(cursed_identity)) = shmap(lambda x: psum(psum(x / 8 / 8, 'i'), 'i')), P(), P())
...
我们越转置,它就变得越大。真尴尬!
而 psum
并不是唯一的罪魁祸首。对于 all_gather
也有类似的情况:
# Example 4: all_gather to an unmapped output
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Currently we get this inefficient transpose
t(f4) = shmap(lambda ybar: psum_scatter(ybar / 8, 'i'), P(), P('i'))
这个程序有点人为。为什么要进行 all_gather
并将结果输入到一个未映射的输出中,而不是跳过主体中的 all_gather
,直接使用 out_specs=P('i')
来收集结果呢?尽管这个例子是人为设计的,但它仍然展示了一个不必要的转置操作,这个操作执行了通信(我们本可以只执行一个不进行通信的切片),类似于 psum
的示例1。
同样类似于 psum
示例,防御性的 psum_scatter
在某些情况下是必要的:
# Example 5: all_gather to a mapped output
f5 = shmap(lambda x, y: all_gather(x, 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Currently we get this efficient transpose
t(f5, 0) = shmap(lambda y, zbar: psum_scatter(zbar * y, 'i'),
in_specs=(P('i'), P('i')), out_specs=P('i'))
那么,我们如何避免这些低效的转置呢?
解决方案#
这里有两个解决方案的想法。它们并不互斥。但(剧透)第二个更好,这是我们需要的全部。
部分解决方案 “P-sum”: 在 out_specs
中构建表达 psum
的能力#
这个解决方案有点像稻草人,因为它只会提供一种笨拙的编写程序的方式。而且它甚至不能解决所有问题!但如果能激励出一个更完整的解决方案,那么它还是值得考虑的。
上面的示例4是人为的,因为我们可以直接在主体中使用 out_specs
而不是 all_gather
:
# Example 4 again
f4 = shmap(lambda x: all_gather(x, 'i'), P('i'), P())
# Why didn't we just write it like this?
f4_better = shmap(lambda x: x, P('i'), P('i'))
f4_better
版本没有任何转置问题,因为转置问题源于主体中的集体操作。
类似地,我们可以通过扩展 out_specs
来修复示例1,使其能够表达求和:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# What if we could write an output sum like this?
f1_better = shmap(g, in_specs=P('i'), out_specs=P(sum='i')) # sum='i' means sum over that axis
# Then it could transpose like this:
t(f1_better) = shmap(t(g), in_specs=P(), out_specs=P('i'))
t(t(f1_better)) = shmap(t(t(g)), in_specs=P('i'), P(sum='i'))
因此,将 psum
内置到 out_specs
中解决了示例1的转置问题。但它并没有完全解决示例3中的诅咒身份转置问题:
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# How it would transpose with the P-sum partial solution:
t(cursed_identity) = shmap(lambda x: x / 8, P(), P(sum='i'))
t(t(cursed_identity)) = shmap(lambda x: x / 8, P(), P(sum='i'))
这是一个改进,因为程序不会随着我们不断转置而变得更大,但我们仍在进行浪费的通信。
完整解决方案:静态跟踪设备变化与设备不变的中间结果,加上新的原语#
该解决方案有两个组成部分:
跟踪值在特定网格轴上何时保证设备不变与设备变化,以及
将
psum
分解为两步过程,引入一个新的pbroadcast
原语,并引入新的all_gather
及其转置的原语。
从道德上讲,跟踪设备不变信息与设备变化信息是一种类型级别的考虑。但在我们首次实施的便利性上,我们不需要将这些信息实际添加到抽象值或 jaxpr 类型中。在我们进入实施之前,我们将首先使用类型来介绍这个想法。
接下来将讨论如何使用户API既方便又向后兼容。但首先介绍这个想法时,我们将忽略便利性,而是编写尽可能明确的代码。
在avals中跟踪设备不变性(即avals-with-names,重新启用)#
有时,仅从静态信息中我们就可以断定,在 shmap
的主体中,某些中间变量的值在网格轴上是保证不变的,这意味着沿网格轴的函数实例(及其对应的设备)必须全部使用相同的值进行计算。我们将这些值称为设备不变量。对于那些不是设备不变量的值,我们称它们为设备变化量,尽管实际上我们是从类型系统的角度来理解它们可能是设备变化的。
为了在类型中编码设备差异,我们将扩展数组类型的语法。我们将编写诸如 x:f32[3,4]{i}
的内容,以表明 x
沿网格轴 i
是(潜在地)设备变化的(并且在 shmap
的任何其他网格轴上是设备不变的)。更一般地说,我们将说数组类型语法的语法类似于
shaped_array ::= <dtype>[<int_literal>, ...]<device_variance_type>
device_variance_type ::= {<axis_name>, ...}
我们还将更新类型规则以处理设备差异类型:
对于除集体操作外的其他一阶原语
对于多参数原语,操作数设备差异类型在形状必须相等的地方必须相等,例如
mul x:f32[s1]{r1} y:f32[s2][r2]
除了要求s1 == s2
之外,还要求r1 == r2
。输出设备的方差类型必须与操作数相同。
对于高阶原语
我们只是实例化任何类型变量,包括设备方差类型(并且在检查类型是否相等时,检查它们的设备方差类型是否相等)
(在进行类型推断时,例如对于
cond
的分支,我们取设备方差类型中轴名称集的并集)
对于一阶集体
一个集合体可以接受设备变化或设备不变的输入(沿着与其轴名称参数对应的网格轴);将设备不变的操作数传递给接受设备变化操作数的集合体,反之亦然,都是错误的。
一个集体可以产生设备可变或设备不变的输出
请参见下表 作为额外的好处,无论实现这种类型检查的逻辑是什么,都可以包含
shmap
的“静态分析”检查,以确定shmap
主体函数是否与任何未映射的out_specs
兼容。
以下是一个表格,总结了集体原语的设备变异类型:
名称 |
设备差异类型 |
示例 |
降低到 HLO |
转置 |
---|---|---|---|---|
|
|
|
|
|
|
|
|
no-op (无通信) |
|
|
|
|
|
|
|
|
|
|
n/a |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
这里有一些令人惊讶的事情!
我们引入了几个新的原语,包括
pbroadcast
,有趣的是它降低为一个无操作all_gather_invariant
,它与all_gather
降低到相同的东西,但具有不同的设备方差类型(本质上all_gather
有一个pbroadcast
融合在其中,而all_gather_invariant
没有)pscatter
是all_gather_invariant
的对偶(转置)。
all_gather 有一个设备变化的结果
直观地说,引入 pbroadcast
的原因(除了使类型规则生效之外)是为了让 psum
可以转置为一个物理上的无操作。我们需要 all_gather
产生设备变化的结果,这样我们就可以将其转置为 psum_scatter
;如果我们让它保持设备不变的结果,我们可能需要一个下游的 pbroadcast
,这种组合会转置为一个低效的 psum
后跟切片 / pscatter
。因此,我们有一个 pbroadcast
“融合到” all_gather
中,从而允许高效地转置为 psum_scatter
。我们提供 all_gather_invariant
及其转置 pscatter
主要是为了完整性;用户不太可能需要它(它对应于示例 4 中的情况,可以使用 out_specs
以不同的方式轻松编写)。
有趣的是,psum
和 pbroadcast
转置对对应于用户在使用 pmap
训练 LLMs 时引入的 psum_idrev
和 id_psumrev
。
这个系统如何解决低效的转置示例#
再次考虑这个简化的激励示例:
# Example 1 again
f1 = shmap(lambda x: psum(g(x), 'i'),
in_specs=P('i'), out_specs=P())
# Example 1 with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1(x: f32[3,4]{i}):
w:f32[]{i} = g(x)
y:f32[]{} = psum(w, 'i')
return y
根据这些新规则,转置为:
# Example 1 transpose using device variance types (go ahead and transpose this again!)
t(f1) = shmap(lambda ybar: t(g)(pbroadcast(ybar, 'i')),
in_specs=P(), out_specs=P('i'))
# Example 1 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P('i'), out_specs=P())
def f1_transpose(ybar: f32[]):
wbar:f32[]{i} = pbroadcast(ybar, 'i')
xbar:f32[3,4]{i} = transpose(g)(wbar)
return xbar
在评估 pbroadcast
应用程序时,完全不涉及通信或浮点运算;它是一个空操作。注意,如果我们继续转置,主体的大小不会增加;实际上 t(t(f1)) == f1
。效率达成!
我们也不会搞乱其他示例,只要我们在需要的地方使用 pbroadcast
来检查类型:
# Example 2 rewritten with explicit pbroadcast
f2 = shmap(lambda x, y: pbroadcast(psum(g(x), 'i'), 'i') * y,
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 2 transpose using device variance types
t(f2, 0) = shmap(lambda y, zbar: t(g)(pbroadcast(psum(zbar * y, 'i'), 'i')),
in_specs=(P('i'), P('i')), out_specs=P('i'))
# Example 3 again
cursed_identity = shmap(lambda x: x, P(), P())
# Notice here the body is `f32[...] -> f32[...]`, i.e. no device varying type.
# Example 3 transpose using device variance types
t(cursed_identity) = shmap(lambda x: x, P(), P())
t(t(cursed_identity)) = shmap(lambda x: x, P(), P())
直观地看,在示例1中,我们现在只有“原始psum的一半”,而在示例2中,我们得到了“两半”。对于示例3,我们根本不需要在主体中进行任何操作。
对于 all_gather
示例,示例 4 需要使用 all_reduce_invariant
来实现高效转置(尽管最好在主体中使用 out_specs
而不是集体操作):
# Example 4 rewritten with explicit all_reduce_invariant
f4 = shmap(lambda x: all_gather_invariant(x, 'i'), P('i'), P())
# Example 4 with intermediate device variance types annotated
@partial(shmap, P('i'), P())
def f4(x:f32[1]{i}):
y:f32[8]{} = all_gather_invariant(x, 'i')
return y
# Example 4 transpose with intermediate device variance types annotated
@partial(shmap, in_specs=P(), out_specs=P('i'))
def f4_transpose(ybar:f32[8]):
xbar:f32[1]{i} = pscatter(ybar, 'i')
return xbar
例如在示例 5 中,使用设备变化的 all_gather
可以按我们期望的方式工作:
# Example 5 with intermediate device variance types annotated
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5(x:f32[1]{i}, y:f32[8]{i}):
z:f32[8]{i} = all_gather(x, 'i')
w:f32[8]{i} = z * y
return w
# Transpose with respect to first argument
@partial(shmap, in_specs=(P('i'), P('i')), out_specs=P('i'))
def f5_transpose(y:f32[8]{i}, wbar:f32[8]{i}):
zbar:f32[8]{i} = wbar * y
xbar:f32[1]{i} = psum_scatter(zbar, 'i')
return xbar
如何使API对用户方便(且向后兼容)#
但是用户想要写 pbroadcast
吗?开发者又想要破坏大量涉及 psum
的现有用户代码,而这些代码并未输入到未映射的输出中吗?我可不想!
相反,我们可以自动插入 pbroadcast
。这有点类似于我们在 jax.numpy
层进行自动秩提升的方式,插入广播以避免二元运算符中的秩不匹配错误。但由于我们不需要处理形状元组,因此它要简单得多。典型的规则是:每当我们看到操作数在其设备方差类型上存在分歧的多重操作时,取操作数设备方差类型的轴名集合的并集,并插入 pbroadcast
以将每个操作数提升到结果设备方差类型。
在需要之前自动插入 pbroadcast
可能意味着我们对同一个操作数多次应用相同的 pbroadcast
,从而创建公共子表达式。当我们进行转置时,这些可能会变成 psum
的和,而不是和的 psum
。我们将依赖编译器在适当的时候清理这些。如果这是一个问题,那么我们可以在 pbroadcast
插入过程中添加一些简单的记忆化。
all_gather
的用户 API 默认意味着 all_gather_p
(不是 all_gather_invariant_p
),涵盖了常见情况,意味着不需要插入 pbroadcast
。
我们可以在 shmap
上提供一个选项来禁用 pbroadcast
的自动插入,在这种情况下,用户需要确保类型正确性。这个显式选项可能对那些希望在反向传播中明确 psum
出现位置的人有吸引力。
如何实现解决方案#
实现轻量级的关键在于 我们不会将这些类型添加到 avals 或 jaxprs 中。至少,一开始不会。这样做可能会很昂贵,因为它需要更新 JAX 的其余部分,例如所有使用 avals 和 jaxprs 的消费者可能需要处理新类型。我们不会再犯那个错误了!
相反,我们将把这些扩展类型作为元数据保留在 shmap
内部,就像当前的“out_specs
复制检查”机制是 shmap
内部的一样。实际上,这个解决方案相当于对现有机制的一个相对较小的扩展:它已经在跟踪相同的信息;现在我们只是添加了 pbroadcast
。
我们至少有两种选择来执行 pbroadcast
插入的位置:
在转置之前,在转置规则中,我们有一个表示要转置的计算的 jaxpr;
在每个
shmap
体中,无论是立即执行还是暂存,就像当前的“out_specs
复制检查”机制。前者可能更容易,因为我们只需要处理 jaxpr 情况,并且只涉及线性原语。但我们将首先尝试后者,因此这里的实现是对现有复制检查逻辑的严格修订/扩展。
附录:定义和激励具有未映射输入和输出的映射#
为了具体化,我们将主要关注 shmap
,尽管这些相同的想法也适用于例如 pmap
和可能的 xmap
。
当 in_specs
中对应的条目没有提到网格轴的名称时,参数/输入在网格轴上是 未映射 的。逻辑上这意味着沿该网格轴的每个函数实例都获得该参数的相同值。对于调用者来说,每个操作数根据操作数映射的网格轴进行切片,而对于操作数未映射的网格轴则没有切片。
当 out_specs
中没有提及某个网格轴的名称时,输出在该网格轴上是 未映射 的。从逻辑上讲,这意味着沿该网格轴的每个函数实例必须返回相同的值。对于调用者来说,shmap
的每个结果是通过连接沿输出映射的每个函数实例的返回值形成的,而对于输出未映射的网格轴,仅使用该值的一个副本。
参见 the shmap
JEP 以获取未映射输入和输出的示例。作为对比,在 vmap
中,未映射的输入/输出通过使用 in_axes
/ out_axes
为 None
来表示(而不是一个 int
)。
以下是我们喜欢 shmap
的未映射输入和输出的原因:
与
pjit
相同的表达能力。pjit
能做的任何事情,shmap
逃生舱也应该能够做到。否则,我们就会有一个缺乏功能的逃生舱!如果shmap
中没有未映射的输出,那么我们就无法表达与pjit
相同的批量并行损失函数计算。闭包输入。 闭包输入本质上对应于未映射的输入,并且…
转置下的闭包。 一旦我们有了未映射的输入,自然能够转置为未映射的输出。
因此,未映射的输出既规范又实用!