jax.experimental.custom_partitioning
模块#
API#
- jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())[源代码][源代码]#
将一个 CustomCallOp 插入到 XLA 图中,并带有自定义的 SPMD 降低规则。
@custom_partitioning def f(*args): return ... def propagate_user_sharding(mesh, user_shape): '''Update the sharding of the op from a user's shape.sharding.''' user_sharding = jax.tree.map(lambda x: x.sharding, user_shape) def partition(mesh, arg_shapes, result_shape): def lower_fn(*args): ... builds computation on per-device shapes ... result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) # result_sharding and arg_shardings may optionally be modified and the # partitioner will insert collectives to reshape. return mesh, lower_fn, result_sharding, arg_shardings def infer_sharding_from_operands(mesh, arg_shapes, shape): '''Compute the result sharding from the sharding of the operands.''' arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
def_partition
的参数如下:propagate_user_sharding
: 一个可调用对象,它接收用户的分片(在dag中)并返回一个新的 NamedSharding 建议。默认实现只是返回建议的分片。partition
: 可调用对象,它接收SPMD建议的分区形状和分区规格,并返回网格、每个分片的降低函数以及最终的输入和输出分片规格(SPMD分区器将重新分区输入以匹配)。返回网格以允许在没有提供网格时为集体配置axis_names。infer_sharding_from_operands
: 可调用对象,用于根据每个参数选择的NamedSharding
计算输出NamedSharding
。decode_shardings
: 当设置为 True 时,如果可能,将输入的GSPMDSharding
转换为NamedSharding
。如果用户没有提供上下文网格,这可能无法实现。
位置参数可以通过 static_argnums 指定为静态。JAX 使用
inspect.signature(fun)
来解析这些位置参数。示例
作为一个例子,假设我们想要增强现有的
jax.numpy.fft.fft
。这个函数计算沿着最后一个维度的 N 维输入的离散傅里叶变换,并且在第一个 N-1 维度上进行批处理。然而,默认情况下,它会忽略输入的分片并在所有设备上收集输入。然而,由于jax.numpy.fft.fft
在第一个 N-1 维度上进行了批处理,这是不必要的。我们将创建一个新的my_fft
操作,它不会改变第一个 N-1 维度的分片,并且仅在需要时沿着最后一个维度收集输入。import jax from jax.sharding import NamedSharding from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax.sharding import Mesh from jax.numpy.fft import fft import regex as re import numpy as np # Pattern to detect all-gather or dynamic-slice in the generated HLO _PATTERN = '(dynamic-slice|all-gather)' # For an N-D input, keeps sharding along the first N-1 dimensions # but replicate along the last dimension def supported_sharding(sharding, shape): rank = len(shape.shape) max_shared_dims = min(len(sharding.spec), rank-1) names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims)) return NamedSharding(sharding.mesh, P(*names)) def partition(mesh, arg_shapes, result_shape): result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return mesh, fft, supported_sharding(arg_shardings[0], arg_shapes[0]), (supported_sharding(arg_shardings[0], arg_shapes[0]),) def infer_sharding_from_operands(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return supported_sharding(arg_shardings[0], arg_shapes[0]) @custom_partitioning def my_fft(x): return fft(x) my_fft.def_partition( infer_sharding_from_operands=infer_sharding_from_operands, partition=partition)
现在创建一个沿第一个轴分片的2D数组,将其传递给
my_fft
并注意它如何仍然按预期分片,并且与fft
的输出相同。然而,检查HLO(使用lower(x).compile().runtime_executable().hlo_modules()
)显示my_fft
没有创建任何全收集或动态切片,而fft
则有。with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]] # jax.numpy.fft.fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
由于
supported_sharding
中的逻辑,my_fft
也可以处理一维数组。然而,在这种情况下,my_fft
的 HLO 确实显示了一个动态切片,因为最后一个维度是计算 FFT 的维度,需要在所有设备上复制该维度后才能进行计算。with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j] # jax.numpy.fft.fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]