jax.lax 模块#

jax.lax 是一个基础操作库,支撑着 jax.numpy 等库。变换规则,如 JVP 和批处理规则,通常定义为 jax.lax 原语的变换。

许多原语是围绕等效的 XLA 操作的薄包装,这些操作由 XLA 操作语义 文档描述。在少数情况下,JAX 与 XLA 有所不同,通常是为了确保操作集在 JVP 和转置规则的操作下是封闭的。

如果可能,建议使用 jax.numpy 库,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,且更不容易发生变化。

运算符#

abs(x)

逐元素绝对值:\(|x|\)

acos(x)

逐元素反余弦:\(\mathrm{acos}(x)\)

acosh(x)

逐元素反双曲余弦:\(\mathrm{acosh}(x)\)

add(x, y)

逐元素加法:\(x + y\)

after_all(*operands)

合并一个或多个 XLA 令牌值。

approx_max_k(operand, k[, ...])

以近似的方式返回 operand 中的最大 k 个值及其索引。

approx_min_k(operand, k[, ...])

以近似方式返回 operand 中的最小 k 值及其索引。

argmax(operand, axis, index_dtype)

计算沿 axis 的最大元素的索引。

argmin(operand, axis, index_dtype)

计算沿 axis 的最小元素的索引。

asin(x)

逐元素反正弦:\(\mathrm{asin}(x)\)

asinh(x)

逐元素反双曲正弦函数:\(\mathrm{asinh}(x)\)

atan(x)

逐元素反正切:\(\mathrm{atan}(x)\)

atan2(x, y)

两个变量的逐元素反正切:\(\mathrm{atan}({x \over y})\)

atanh(x)

逐元素反双曲正切函数:\(\mathrm{atanh}(x)\)

batch_matmul(lhs, rhs[, precision])

批量矩阵乘法。

bessel_i0e(x)

指数缩放的0阶修正贝塞尔函数:\(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\)

bessel_i1e(x)

指数缩放的修正贝塞尔函数,阶数为1:\(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\)

betainc(a, b, x)

逐元素正则化不完全贝塔积分。

bitcast_convert_type(operand, new_dtype)

逐元素位转换。

bitwise_and(x, y)

逐元素与运算:\(x \wedge y\)

bitwise_not(x)

逐元素非运算::math:` eg x`。

bitwise_or(x, y)

逐元素或运算:\(x \vee y\)

bitwise_xor(x, y)

逐元素异或运算:\(x \oplus y\)

population_count(x)

逐元素popcount,计算每个元素中设置的位数。

broadcast(operand, sizes)

广播一个数组,添加新的前导维度

broadcast_in_dim(operand, shape, ...)

封装了 XLA 的 BroadcastInDim 操作符。

broadcast_shapes()

返回由NumPy广播 shapes 得到的结果形状。

broadcast_to_rank(x, rank)

x 的维度前添加 1 以使其秩为 rank

broadcasted_iota(dtype, shape, dimension)

围绕 iota 的便捷包装器。

cbrt(x)

逐元素立方根:\(\sqrt[3]{x}\)

ceil(x)

逐元素上限:\(\left\lceil x \right\rceil\)

clamp(min, x, max)

逐元素夹紧。

clz(x)

逐元素计算前导零的个数。

collapse(operand, start_dimension[, ...])

将数组的维度折叠成单一维度。

complex(x, y)

逐元素生成复数:\(x + jy\)

concatenate(operands, dimension)

沿 dimension 维度连接一系列数组。

conj(x)

逐元素复共轭函数:\(\overline{x}\)

conv(lhs, rhs, window_strides, padding[, ...])

conv_general_dilated 的便捷包装器。

convert_element_type(operand, new_dtype)

逐元素类型转换。

conv_dimension_numbers(lhs_shape, rhs_shape, ...)

将卷积 dimension_numbers 转换为 ConvDimensionNumbers

conv_general_dilated(lhs, rhs, ...[, ...])

通用 n 维卷积操作符,可选的膨胀。

conv_general_dilated_local(lhs, rhs, ...[, ...])

通用 n 维非共享卷积运算符,可选膨胀。

conv_general_dilated_patches(lhs, ...[, ...])

提取受 conv_general_dilated 感受野约束的补丁。

conv_transpose(lhs, rhs, strides, padding[, ...])

用于计算 N 维卷积“转置”的便捷包装器。

conv_with_general_padding(lhs, rhs, ...[, ...])

conv_general_dilated 的便捷包装器。

cos(x)

逐元素余弦:\(\mathrm{cos}(x)\)

cosh(x)

逐元素双曲余弦:\(\mathrm{cosh}(x)\)

cumlogsumexp(operand[, axis, reverse])

计算沿 axis 的累积 logsumexp。

cummax(operand[, axis, reverse])

计算沿 axis 的累积最大值。

cummin(operand[, axis, reverse])

沿 axis 计算累积最小值。

cumprod(operand[, axis, reverse])

计算沿 的累积乘积。

cumsum(operand[, axis, reverse])

计算沿 axis 的累积和。

digamma(x)

逐元素的digamma函数:\(\psi(x)\)

div(x, y)

逐元素除法:\(x \over y\)

dot(lhs, rhs[, precision, ...])

向量/向量、矩阵/向量和矩阵/矩阵乘法。

dot_general(lhs, rhs, dimension_numbers[, ...])

通用点积/收缩运算符。

dynamic_index_in_dim(operand, index[, axis, ...])

围绕 dynamic_slice 的便捷包装器,用于执行整数索引。

dynamic_slice(operand, start_indices, ...)

封装了XLA的 DynamicSlice 操作符。

dynamic_slice_in_dim(operand, start_index, ...)

围绕 lax.dynamic_slice() 应用于一个维度的便捷包装器。

dynamic_update_index_in_dim(operand, update, ...)

围绕 dynamic_update_slice() 的便捷包装器,用于在单个 axis 中更新大小为 1 的切片。

dynamic_update_slice(operand, update, ...)

封装了 XLA 的 DynamicUpdateSlice 操作符。

dynamic_update_slice_in_dim(operand, update, ...)

围绕 dynamic_update_slice() 的便捷包装器,用于在单个 中更新切片。

eq(x, y)

逐元素相等:\(x = y\)

erf(x)

逐元素误差函数:\(\mathrm{erf}(x)\)

erfc(x)

逐元素互补误差函数:\(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\)

erf_inv(x)

逐元素逆误差函数:\(\mathrm{erf}^{-1}(x)\)

exp(x)

逐元素指数:\(e^x\)

expand_dims(array, dimensions)

在数组中插入任意数量的尺寸为1的维度。

expm1(x)

逐元素 \(e^{x} - 1\)

fft(x, fft_type, fft_lengths)

floor(x)

逐元素向下取整:\(\left\lfloor x \right\rfloor\)

full(shape, fill_value[, dtype, sharding])

返回一个用 fill_value 填充的 shape 数组。

full_like(x, fill_value[, dtype, shape, ...])

基于示例数组 x 创建一个完整的数组,类似于 np.full。

gather(operand, start_indices, ...[, ...])

收集操作符。

ge(x, y)

逐元素大于等于:\(x \geq y\)

gt(x, y)

逐元素大于:\(x > y\)

igamma(a, x)

逐元素正则化不完全伽马函数。

igammac(a, x)

逐元素互补正则化不完全伽马函数。

imag(x)

逐元素提取虚部:\(\mathrm{Im}(x)\)

index_in_dim(operand, index[, axis, keepdims])

围绕 lax.slice() 的便捷包装器,用于执行整数索引。

index_take(src, idxs, axes)

integer_pow(x, y)

逐元素幂:\(x^y\),其中 \(y\) 是一个固定的整数。

iota(dtype, size)

封装了 XLA 的 Iota 操作符。

is_finite(x)

逐元素 \(\mathrm{isfinite}\)

le(x, y)

逐元素小于等于:\(x \leq y\)

lgamma(x)

逐元素对数伽马函数:\(\mathrm{log}(\Gamma(x))\)

log(x)

逐元素自然对数:\(\mathrm{log}(x)\)

log1p(x)

逐元素计算 \(\mathrm{log}(1 + x)\)

logistic(x)

逐元素逻辑(sigmoid)函数:\(\frac{1}{1 + e^{-x}}\)

lt(x, y)

逐元素小于:\(x < y\)

max(x, y)

逐元素最大值:\(\mathrm{max}(x, y)\)

min(x, y)

逐元素最小值:\(\mathrm{min}(x, y)\)

mul(x, y)

逐元素乘法: \(x imes y\).

ne(x, y)

逐元素不等于:\(x eq y\)

neg(x)

逐元素取反:\(-x\)

nextafter(x1, x2)

返回在 x2 方向上 x1 之后的下一个可表示的值。

optimization_barrier(operand, /)

防止编译器将操作跨越屏障移动。

pad(operand, padding_value, padding_config)

对数组应用低、高和/或内部填充。

platform_dependent(*args[, default])

移除平台特定的代码。

polygamma(m, x)

逐元素多伽玛函数:\(\psi^{(m)}(x)\)

population_count(x)

逐元素popcount,计算每个元素中设置的位数。

pow(x, y)

逐元素幂:\(x^y\)

random_gamma_grad(a, x)

来自 Gamma(a, 1) 的样本的逐元素导数。

real(x)

逐元素提取实部:\(\mathrm{Re}(x)\)

reciprocal(x)

逐元素倒数:\(1 \over x\)

reduce(operands, init_values, computation, ...)

封装了 XLA 的 Reduce 操作符。

reduce_precision(operand, exponent_bits, ...)

封装了 XLA 的 ReducePrecision 操作符。

reduce_window(operand, init_value, ...[, ...])

rem(x, y)

逐元素取余:\(x \bmod y\)

reshape(operand, new_sizes[, dimensions])

封装了XLA的 Reshape 操作符。

rev(operand, dimensions)

封装了 XLA 的 Rev 操作符。

rng_bit_generator(key, shape[, dtype, algorithm])

无状态的伪随机数生成器位生成器。

rng_uniform(a, b, shape)

有状态的伪随机数生成器。

round(x[, rounding_method])

逐元素四舍五入。

rsqrt(x)

逐元素的倒数平方根: \(1 \over \sqrt{x}\)

scatter(operand, scatter_indices, updates, ...)

散点更新操作符。

scatter_add(operand, scatter_indices, ...[, ...])

散点加法运算符。

scatter_apply(operand, scatter_indices, ...)

分散-应用操作符。

scatter_max(operand, scatter_indices, ...[, ...])

Scatter-max 操作符。

scatter_min(operand, scatter_indices, ...[, ...])

Scatter-min 操作符。

scatter_mul(operand, scatter_indices, ...[, ...])

散点乘法运算符。

shift_left(x, y)

逐元素左移:\(x \ll y\)

shift_right_arithmetic(x, y)

逐元素算术右移:\(x \gg y\)

shift_right_logical(x, y)

逐元素逻辑右移:\(x \gg y\)

sign(x)

逐元素符号。

sin(x)

逐元素正弦:\(\mathrm{sin}(x)\)

sinh(x)

逐元素双曲正弦函数:\(\mathrm{sinh}(x)\)

slice(operand, start_indices, limit_indices)

封装了 XLA 的 Slice 操作符。

slice_in_dim(operand, start_index, limit_index)

围绕 lax.slice() 的便捷包装器,仅应用于一个维度。

sort()

封装了 XLA 的 排序 操作符。

sort_key_val(keys, values[, dimension, ...])

沿 dimensionkeys 进行排序,并将相同的排列应用于 values

sqrt(x)

逐元素平方根:\(\sqrt{x}\)

square(x)

逐元素平方:\(x^2\)

squeeze(array, dimensions)

从数组中挤压任意数量的尺寸为1的维度。

sub(x, y)

逐元素减法:\(x - y\)

tan(x)

逐元素正切:\(\mathrm{tan}(x)\)

tanh(x)

逐元素双曲正切函数:\(\mathrm{tanh}(x)\)

top_k(operand, k)

返回 operand 最后一个轴上的前 k 个值及其索引。

transpose(operand, permutation)

封装了 XLA 的 Transpose 操作符。

zeros_like_array(x)

zeta(x, q)

逐元素Hurwitz zeta函数:\(\zeta(x, q)\)

控制流操作符#

associative_scan(fn, elems[, reverse, axis])

在并行环境中执行带有关联二元操作的扫描。

cond(pred, true_fun, false_fun, *operands[, ...])

有条件地应用 true_funfalse_fun

fori_loop(lower, upper, body_fun, init_val, *)

lowerupper 通过归约为 jax.lax.while_loop() 进行循环。

map(f, xs, *[, batch_size])

对数组的前导轴应用函数。

scan(f, init[, xs, length, reverse, unroll, ...])

在沿用状态的同时,扫描函数遍历数组的前导轴。

select(pred, on_true, on_false)

根据布尔谓词在两个分支之间进行选择。

select_n(which, *cases)

从多个案例中选择数组值。

switch(index, branches, *operands[, operand])

应用由 index 给出的 branches 中的一个。

while_loop(cond_fun, body_fun, init_val)

cond_fun 为 True 时,循环调用 body_fun

自定义梯度运算符#

stop_gradient(x)

停止梯度计算。

custom_linear_solve(matvec, b, solve[, ...])

使用隐式定义的梯度执行无矩阵线性求解。

custom_root(f, initial_guess, solve, ...[, ...])

可微分地求解函数的根。

并行运算符#

all_gather(x, axis_name, *[, ...])

收集所有副本中的 x 值。

all_to_all(x, axis_name, split_axis, ...[, ...])

实现映射的轴并映射不同的轴。

psum(x, axis_name, *[, axis_index_groups])

x 上计算 axis_name 轴上的 pmapped 轴的 all-reduce 和。

psum_scatter(x, axis_name, *[, ...])

类似于 psum(x, axis_name),但每个设备只保留结果的一部分。

pmax(x, axis_name, *[, axis_index_groups])

axis_name 的 pmapped 轴上对 x 进行 all-reduce max 计算。

pmin(x, axis_name, *[, axis_index_groups])

axis_name 的 pmapped 轴上对 x 进行 all-reduce min 计算。

pmean(x, axis_name, *[, axis_index_groups])

axis_name 的 pmapped 轴上对 x 进行 all-reduce 均值计算。

ppermute(x, axis_name, perm)

根据排列 perm 执行集体排列。

pshuffle(x, axis_name, perm)

jax.lax.ppermute 的便捷包装,具有替代的排列编码

pswapaxes(x, axis_name, axis, *[, ...])

将pmapped轴 axis_name 与未映射的轴 axis 交换。

axis_index(axis_name)

返回沿映射轴 axis_name 的索引。

线性代数运算符 (jax.lax.linalg)#

cholesky(x, *[, symmetrize_input])

Cholesky 分解。

eig(x, *[, compute_left_eigenvectors, ...])

一般矩阵的特征分解。

eigh(x, *[, lower, symmetrize_input, ...])

埃尔米特矩阵的特征分解。

hessenberg(a)

将一个方阵简化为上Hessenberg形式。

lu(x)

部分枢轴旋转的LU分解。

householder_product(a, taus)

基本Householder反射器的乘积。

qdwh(x, *[, is_hermitian, max_iterations, ...])

基于QR的动态加权Halley迭代用于极分解。

qr(x, *[, full_matrices])

QR 分解。

schur(x, *[, compute_schur_vectors, ...])

svd()

奇异值分解。

triangular_solve(a, b, *[, left_side, ...])

三角求解。

tridiagonal(a, *[, lower])

将一个对称/厄米矩阵简化为三对角形式。

tridiagonal_solve(dl, d, du, b)

计算三对角线性系统的解。

参数类#

class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)[源代码][源代码]#

描述卷积的批次、空间和特征维度。

参数:
  • lhs_spec (Sequence[int]) – 一个包含 (批量维度, 特征维度, 空间维度…) 的非负整数维度数的元组。

  • rhs_spec (Sequence[int]) – 一个包含非负整数维度数的元组 (输出特征维度, 输入特征维度, 空间维度…)

  • out_spec (Sequence[int]) – 一个包含 (批量维度, 特征维度, 空间维度…) 的非负整数维度数的元组。

jax.lax.ConvGeneralDilatedDimensionNumbers#

tuple[str, str, str] | ConvDimensionNumbers | None 的别名

class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)[源代码][源代码]#

描述了传递给 XLA 的 Gather 操作符 的维度编号参数。有关维度编号的更多详细信息,请参阅 XLA 文档。

参数:
  • offset_dims (tuple[int, ...]) – gather 输出中用于偏移到从 operand 切片得到的数组的一组维度。必须是一个按升序排列的整数元组,每个整数代表输出中的一个维度编号。

  • collapsed_slice_dims (tuple[int, ...]) – 在 operand 中,维度 i 的集合,其 slice_sizes[i] == 1,并且不应在 gather 的输出中具有相应的维度。必须是一个按升序排列的整数元组。

  • start_index_map (tuple[int, ...]) – 对于 start_indices 中的每个维度,给出 operand 中要被切片的相应维度。必须是一个整数元组,其大小等于 start_indices.shape[-1]

与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要收集标量索引,请添加一个大小为 1 的尾随维度。

class jax.lax.GatherScatterMode(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码][源代码]#

描述了如何在 gather 或 scatter 中处理越界索引。

可能的值是:

CLIP:

索引将被限制为最接近的范围内的值,即,使得要收集的整个窗口都在范围内。

FILL_OR_DROP:

如果收集的窗口的任何部分超出边界,返回的整个窗口,即使是那些原本在边界内的元素,也将被填充为一个常量。如果分散的窗口的任何部分超出边界,整个窗口将被丢弃。

PROMISE_IN_BOUNDS:

用户承诺索引在边界内。不会进行额外的检查。实际上,使用当前的XLA实现,这意味着越界的gather操作将被钳制,而越界的scatter操作将被丢弃。如果索引越界,梯度将不正确。

class jax.lax.Precision(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码][源代码]#

用于宽松矩阵乘法相关函数的精度枚举。

JAX 函数的设备相关 precision 参数通常控制加速器后端(即 TPU 和 GPU)上数组计算的速度和精度之间的权衡。对 CPU 后端没有影响。这仅对 float32 计算有影响,不影响输入/输出数据类型。成员包括:

默认:

最快模式,但准确性最低。在TPU上:以bfloat16执行float32计算。在GPU上:如果可用则使用tensorfloat32(例如在A100和H100 GPU上),否则使用标准float32(例如在V100 GPU上)。别名:'default''fastest'

高:

较慢但更准确。在TPU上:以3次bfloat16传递执行float32计算。在GPU上:在可用时使用tensorfloat32,否则使用float32。别名:'high'

最高:

最慢但最准确。在TPU上:以6个bfloat16执行float32计算。别名:'highest'。在GPU上:使用float32。

jax.lax.PrecisionLike#

str | Precision | tuple[str, str] | tuple[Precision, Precision] | None 的别名

class jax.lax.RoundingMethod(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码][源代码]#
class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)[源代码][源代码]#

描述了 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号的更多详细信息,请参阅 XLA 文档。

参数:
  • update_window_dims (Sequence[int]) – updates 中窗口维度的集合。必须是一个按升序排列的整数元组,每个整数代表一个维度编号。

  • inserted_window_dims (Sequence[int]) – 必须插入到 updates 形状中的大小为 1 的窗口尺寸集合。必须是一个按升序排列的整数元组,每个整数代表输出的一维。在 gather 的情况下,这些是 collapsed_slice_dims 的镜像。

  • scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是一个整数序列,其大小等于 scatter_indices.shape[-1]

与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是存在一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,请添加一个大小为 1 的尾随维度。