Custom operations for GPUs with C++ and CUDA#
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations. This tutorial explains how we can define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from Extending JAX with custom C++ and CUDA code and supposes that you are familiar with JAX primitives.
RMS normalization#
For this tutorial, we are going to add RMS normalization as a custom operation in JAX. Note that RMS normalization can be expressed with jax.numpy directly. However, we are using it as an example to show the process of creating a custom operation for GPUs. The CUDA code in gpu_ops/rms_norm_kernels.cu has been borrowed from Apex and adapted to eliminate any dependency on PyTorch.
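As a point of reference, the same computation can be written in a few lines of jax.numpy. The sketch below is only illustrative and is not part of the custom-operation code; the function name rms_norm_ref and the choice to normalize over the trailing axes covered by weight are our assumptions, chosen to match the n1/n2 convention of the kernel used later.

import jax.numpy as jnp

def rms_norm_ref(x, weight, eps=1e-05):
    # Normalize over the trailing axes spanned by `weight`, then apply the scale.
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=axes, keepdims=True) + eps)
    return x / rms * weight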
High-level steps#
This tutorial shows how to write both a custom operation and its gradient.
In C: you need to follow these steps for each new JAX primitive:
- Have a CUDA kernel.
- Create a C function that dispatches the CUDA kernel and that will be called by XLA.
- Create a descriptor to convey the information needed for the computation: the types, the shapes and other attributes.
- Bind the C function to Python: to create the descriptor and to call the primitive during execution.
In Python: you need to follow these steps (a minimal skeleton of these registrations is sketched after this list):
- Define a new JAX primitive (instruction/operation).
- Write a Python function to build the graph nodes with the primitive.
- Define its abstract evaluation.
- Define its lowering to MLIR.
- [Optional] Define the gradient.
- [Optional] Use custom_partitioning or shard_map functions for fast multi-GPU execution.
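The concrete versions of these registrations for RMS normalization appear in the sections below. The following is only a generic skeleton of the Python-side steps; the placeholder names _my_op_p, my_op, _my_op_abstract and _my_op_lowering are invented for this sketch and do not exist elsewhere in the tutorial.

from functools import partial

from jax import core
from jax.interpreters import mlir, xla

# 1. Define a new primitive and bind its default implementation to XLA.
_my_op_p = core.Primitive("my_op")
_my_op_p.multiple_results = False
_my_op_p.def_impl(partial(xla.apply_primitive, _my_op_p))

# 2. A Python function that builds graph nodes with the primitive.
def my_op(x):
    return _my_op_p.bind(x)

# 3. Register its abstract evaluation (shape/dtype inference):
#    _my_op_p.def_abstract_eval(_my_op_abstract)
# 4. Register its lowering to an MLIR custom call:
#    mlir.register_lowering(_my_op_p, _my_op_lowering, platform="gpu")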
C code#
See the gpu_ops code listing for a complete code listing of the C++ and CUDA files. gpu_ops/rms_norm_kernels.cu defines the following functions, which are declared with the XLA custom function signature. These functions are responsible for launching the RMS normalization kernels with the given buffers on the specified stream.
namespace gpu_ops {
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
                                     const char *opaque,
                                     std::size_t opaque_len);
void rms_backward_affine(cudaStream_t stream, void **buffers,
                         const char *opaque, std::size_t opaque_len);
} // namespace gpu_ops
stream is the CUDA stream to be used for executing any kernel on the GPU. buffers has all pointers to the input buffers followed by all pointers to the output buffers. opaque is a buffer used to pass any extra information to the custom functions, and opaque_len is the length of opaque.
In this tutorial, an RMSNormDescriptor object will be passed to these functions as opaque.
namespace gpu_ops {
enum ElementType { BF16, F16, F32, F64 };
struct RMSNormDescriptor {
  int n1;
  int n2;
  double eps;
  ElementType x_type;
  ElementType w_type;
  int part_grad_size;
};
} // namespace gpu_ops
Now we need to expose these functions, as well as ElementType and RMSNormDescriptor, as a Python module, gpu_ops, through pybind11.
pybind11::dict RMSNormRegistrations() {
  pybind11::dict dict;
  dict["rms_forward_affine_mixed_dtype"] =
      gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
  dict["rms_backward_affine"] =
      gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
  return dict;
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
n1, n2, eps, x_type, w_type, part_grad_size});
});
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
.value("BF16", gpu_ops::ElementType::BF16)
.value("F16", gpu_ops::ElementType::F16)
.value("F32", gpu_ops::ElementType::F32)
.value("F64", gpu_ops::ElementType::F64);
}
Build the gpu_ops extension module#
We build the gpu_ops Python extension module with the code above. (See the gpu_ops code listing for a complete code listing of the C++ and CUDA files.)
python -m pip install pybind11==2.10.1
mkdir -p build
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
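After the build, the module can be imported and inspected from Python. The snippet below is only a quick smoke test, assuming the build artifacts live in the build package created above; the descriptor values are illustrative.

from build import gpu_ops

# The custom call targets registered by the C++ code.
registrations = gpu_ops.get_rms_norm_registrations()
print(list(registrations.keys()))
# Expected: ['rms_forward_affine_mixed_dtype', 'rms_backward_affine']

# Pack an example descriptor (n1, n2, eps, x_type, w_type, part_grad_size);
# the result is the opaque blob later passed to the custom call as backend_config.
opaque = gpu_ops.create_rms_norm_descriptor(
    32, 512 * 512, 1e-05, gpu_ops.ElementType.BF16, gpu_ops.ElementType.BF16, 0
)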
Add RMS normalization to JAX as a custom call#
gpu_ops is just a Python extension module and we need more work to plug it into JAX.
Create primitives#
We first create the primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, which the custom functions can be mapped to. We set the multiple_results attribute to True for these operations, which means that each operation produces multiple outputs as a tuple. When it is set to False, the operation produces a single output without a tuple. For more details, see How JAX primitives work.
from functools import partial
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output
# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
def rms_norm_bwd(g, invvar, x, weight, eps):
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
Lowering to MLIR custom call#
To map the custom functions to the new primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, we need to:
- register the custom functions as custom call targets with xla_client.register_custom_call_target, and
- register lowering functions that lower the primitives to MLIR custom calls with the registered custom call targets.
The functions _rms_norm_fwd_cuda_lowering and _rms_norm_bwd_cuda_lowering below lower the primitives to MLIR custom call operations with the custom targets from gpu_ops. These functions are registered with jax.interpreters.mlir.register_lowering.
Note that an RMSNormDescriptor object is created in the lowering function and passed to the custom call as opaque.
from functools import reduce
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
def element_type_to_descriptor_type_mapping(element_type):
    _element_type_to_descriptor_type_mapping = {
        ir.BF16Type.get(): gpu_ops.ElementType.BF16,
        ir.F16Type.get(): gpu_ops.ElementType.F16,
        ir.F32Type.get(): gpu_ops.ElementType.F32,
        ir.F64Type.get(): gpu_ops.ElementType.F64,
    }
    return _element_type_to_descriptor_type_mapping.get(element_type)
def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_element_type = (
        ir.F32Type.get()
        if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
        else x_type.element_type
    )
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        0,  # unused
    )
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_fwd_p,
    _rms_norm_fwd_cuda_lowering,
    platform="gpu",
)
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_type = ir.RankedTensorType(invvar.type)
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    part_grad_shape = ctx.avals_out[-1].shape
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        part_grad_shape[0],
    )
    out = custom_call(
        b"rms_backward_affine",
        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
        ],
        operands=[grad_output, invvar, x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
        result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_bwd_p,
    _rms_norm_bwd_cuda_lowering,
    platform="gpu",
)
Let's test it#
per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
    jax.random.key(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
Test the forward function#
out = rms_norm_fwd(x, weight)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)
...
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
Abstract evaluation#
The test above failed with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did the test fail? What does it mean?
As part of the execution, JAX performs abstract evaluation. Since JAX has no knowledge of the new primitives, it doesn't know how to compute the output shapes and output data types, and therefore it can't evaluate these operations abstractly.
We need to provide a function for the abstract evaluation of each primitive. These abstract evaluation functions compute the shapes and data types of the outputs, but they don't compute the actual values of the operations.
These functions are passed to the .def_abstract_eval method to be registered with the corresponding primitives.
See How JAX primitives work for more information on abstract evaluation.
from functools import reduce
from operator import mul
from jax.core import ShapedArray
def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )
_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)
    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    assert iv_dtype == (
        jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
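With the abstract evaluation rules registered, JAX can derive output shapes and dtypes without launching any kernel. The following quick check is a sketch that reuses the x and weight created earlier; the printed shape depends on the number of local devices (32 in the eight-GPU run shown below).

# Trace rms_norm_fwd abstractly; no GPU kernel is launched.
print(jax.eval_shape(rms_norm_fwd, x, weight))
# e.g. ShapeDtypeStruct(shape=(32, 512, 512), dtype=bfloat16)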
Let's test it again#
Test the forward function#
out = rms_norm_fwd(x, weight)
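We can also look at what the primitive now lowers to. This is an optional inspection step that is not part of the original test: it prints the lowered module for the jitted forward function, which should contain a custom_call targeting rms_forward_affine_mixed_dtype.

# Inspect the lowering; look for the "rms_forward_affine_mixed_dtype" custom call.
print(jax.jit(rms_norm_fwd).lower(x, weight).as_text())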
Test the backward function#
Now let's test the backward operation using jax.grad and jtu.check_grads.
def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [8], line 7
3 return -jnp.mean(predictions**2)
6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)
...
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
Differentiation rule#
The backward operation failed with the error NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. This means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX doesn't know the relationship between them.
We can teach JAX that rms_norm_bwd is the backward operation of rms_norm_fwd, using jax.custom_vjp and its convention. As the first step, we need to refine the definitions of rms_norm_fwd and rms_norm_bwd.
# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)


# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
rms_norm_fwd now returns an extra output, (invvar, x, weight), as residual data, and rms_norm_bwd takes eps, res, and g as parameters.
Once the relationship between rms_norm_fwd and rms_norm_bwd is established through jax.custom_vjp, JAX ensures that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation. For non-differentiable parameters such as eps, JAX ensures that they are passed to the backward operation before the residual data. That's why eps precedes res in the parameter list of rms_norm_bwd.
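The calling convention may be easier to see on a toy function. The example below is a self-contained illustration of jax.custom_vjp with nondiff_argnums; the names scale_pow, scale_pow_fwd and scale_pow_bwd are made up for this sketch and are unrelated to the RMS normalization code.

from functools import partial

import jax

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_pow(x, power):
    return x**power

def scale_pow_fwd(x, power):
    # Residuals to be reused by the backward pass.
    return x**power, (x,)

def scale_pow_bwd(power, res, g):
    # Non-differentiable argument first, then residuals, then the cotangent.
    (x,) = res
    return (g * power * x ** (power - 1),)

scale_pow.defvjp(scale_pow_fwd, scale_pow_bwd)

print(jax.grad(scale_pow)(2.0, 3))  # 3 * 2**2 = 12.0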
Since rms_norm_fwd now returns residual data, which is not needed for a plain RMS normalization call, we define a wrapper around it, rms_norm. It simply calls rms_norm_fwd and returns only output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)) and that we pass rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX that, when rms_norm is differentiated, rms_norm_fwd is to be used for the forward operation and rms_norm_bwd for the backward operation.
See Custom derivative rules for JAX-transformable Python functions for more information on jax.custom_vjp.
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
With the refinement we have made, the backward operation test works with one modification: loss now calls rms_norm instead of rms_norm_fwd.
def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
Let's test it on multiple devices#
We use jax.experimental.pjit.pjit for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.
Test the forward function#
Let's first test the forward operation on multiple devices. We create a simple 1D mesh and shard x across all the devices.
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)
with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)
jnp.allclose(ref, out, atol=1e-5, rtol=1e-5)
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}
%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
%param_1 = bf16[32,512,512]{2,1,0} parameter(0)
%param_1.3 = u32[] parameter(1)
%convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
%param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
%all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
%custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
%get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
True
The values have been computed correctly for the forward operation; however, the generated HLO module shows an all-gather operation that replicates x on all the devices, incurring a large communication overhead.
As XLA does not have enough knowledge about the custom functions to shard the input tensors, it decides to replicate them before making the custom call so that correct values are produced.
To avoid that duplication, we can:
- use custom_partitioning: to make it behave like all native JAX operations (but more complicated), or
- use manual sharding.
This example demonstrates the use of custom_partitioning.
Check for correctness#
with Mesh(jax.local_devices(), ("x",)):
    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out

    custom_p_out = run_and_verify(custom_p_loss)

for r, o in zip(ref_out, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}
%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
%param_0 = f16[4,512,512]{2,1,0} parameter(0)
%constant_4_1 = f16[] constant(-4.7684e-07)
%broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}
%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
%Arg_1.11.0 = f16[] parameter(1)
%Arg_0.10.0 = f16[] parameter(0)
ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}
ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
%param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
%param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
%custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
%get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
%get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
%get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
%all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
}
True
True
There is now no all-gather in the HLO: the sharding is respected, and the gradients are accumulated with an all-reduce only.
Let's put it together#
The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code defining the Python bindings, together with the kernel implementations, can be found below:
gpu_ops code listing#
gpu_ops/kernel_helpers.h
gpu_ops/kernels.h
gpu_ops/pybind11_kernel_helpers.h
gpu_ops/gpu_ops.cpp
gpu_ops/rms_norm_kernels.cu