jax.numpy 模块#

使用 jax.lax 中的原语实现了 NumPy API。

虽然 JAX 尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。

  • 值得注意的是,由于 JAX 数组是不可变的,NumPy API 中那些就地修改数组的函数无法在 JAX 中实现。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 提供了替代的就地数组更新函数 x.at[i].set(y) (参见 ndarray.at),而不是就地数组更新 (x[i] = y)。

  • 同样地,一些 NumPy 函数在可能的情况下通常返回数组的视图(例如 transpose()reshape())。JAX 版本的此类函数将返回副本,尽管在使用 jax.jit() 编译操作序列时,XLA 通常会优化掉这些副本。

  • NumPy 在将值提升为 float64 类型时非常激进。JAX 在类型提升方面有时不那么激进(参见 类型提升)。

  • 一些 NumPy 例程的输出形状依赖于数据(例如 unique()nonzero())。由于 XLA 编译器要求在编译时知道数组形状,因此这些操作与 JIT 不兼容。为此,JAX 为这些函数添加了一个可选的 size 参数,可以在使用 JIT 时静态指定该参数。

几乎所有适用的 NumPy 函数都在 jax.numpy 命名空间中实现;它们列在下面。

ndarray.at

用于索引更新功能的辅助属性。

abs(x, /)

别名 jax.numpy.absolute()

absolute(x, /)

逐元素计算绝对值。

acos(x, /)

别名 jax.numpy.arccos()

acosh(x, /)

别名 jax.numpy.arccosh()

add

逐元素相加两个数组。

all(a[, axis, out, keepdims, where])

测试沿给定轴的所有数组元素是否评估为 True。

allclose(a, b[, rtol, atol, equal_nan])

检查两个数组在容差范围内是否逐元素近似相等。

amax(a[, axis, out, keepdims, initial, where])

别名 jax.numpy.max()

amin(a[, axis, out, keepdims, initial, where])

别名 jax.numpy.min()

angle(z[, deg])

返回复数值数或数组的角。

any(a[, axis, out, keepdims, where])

测试沿给定轴的数组元素是否评估为 True。

append(arr, values[, axis])

返回一个新数组,该数组在原数组的末尾附加了值。

apply_along_axis(func1d, axis, arr, *args, ...)

沿给定轴对1-D切片应用函数。

apply_over_axes(func, a, axes)

在多个轴上重复应用函数。

arange(start[, stop, step, dtype, device])

创建一个等间距值的数组。

arccos(x, /)

逐元素计算三角反余弦。

arccosh(x, /)

逐元素计算反双曲余弦。

arcsin(x, /)

逐元素计算反正弦。

arcsinh(x, /)

逐元素计算反双曲正弦。

arctan(x, /)

三角反切函数,逐元素计算。

arctan2(x1, x2, /)

逐元素计算 x1/x2 的反正切值,并正确选择象限。

arctanh(x, /)

逐元素计算反双曲正切。

argmax(a[, axis, out, keepdims])

返回沿某个轴的最大值的索引。

argmin(a[, axis, out, keepdims])

返回沿某个轴的最小值的索引。

argpartition(a, kth[, axis])

返回部分排序数组的索引。

argsort(a[, axis, kind, order, stable, ...])

返回排序数组的索引。

argwhere(a, *[, size, fill_value])

查找非零数组元素的索引

around(a[, decimals, out])

别名 jax.numpy.round()

array(object[, dtype, copy, order, ndmin, ...])

将对象转换为 JAX 数组。

array_equal(a1, a2[, equal_nan])

检查两个数组是否逐元素相等。

array_equiv(a1, a2)

检查两个数组是否逐元素相等。

array_repr(arr[, max_line_width, precision, ...])

返回数组的字符串表示形式。

array_split(ary, indices_or_sections[, axis])

将一个数组分割成子数组。

array_str(a[, max_line_width, precision, ...])

返回数组中数据的字符串表示形式。

asarray(a[, dtype, order, copy, device])

将对象转换为 JAX 数组。

asin(x, /)

别名 jax.numpy.arcsin()

asinh(x, /)

别名 jax.numpy.arcsinh()

astype(x, dtype, /, *[, copy, device])

将数组转换为指定的数据类型。

atan(x, /)

别名 jax.numpy.arctan()

atanh(x, /)

别名 jax.numpy.arctanh()

atan2(x1, x2, /)

别名 jax.numpy.arctan2()

atleast_1d()

将输入转换为至少具有一个维度的数组。

atleast_2d()

将输入视为至少具有两个维度的数组。

atleast_3d()

将输入视为至少具有三个维度的数组。

average()

计算沿指定轴的加权平均值。

bartlett(M)

返回 Bartlett 窗口。

bincount(x[, weights, minlength, length])

计算整数数组中每个值的出现次数。

bitwise_and

按元素计算按位与操作。

bitwise_count(x, /)

计算 x 中每个元素的绝对值的二进制表示中 1 的个数。

bitwise_invert(x, /)

按元素计算位反转,或按位 NOT。

bitwise_left_shift(x, y, /)

将整数的位向左移动。

bitwise_not(x, /)

按元素计算位反转,或按位 NOT。

bitwise_or

按元素计算按位或运算。

bitwise_right_shift(x1, x2, /)

别名 jax.numpy.right_shift()

bitwise_xor

逐元素计算按位异或运算。

blackman(M)

返回布莱克曼窗。

block(arrays)

从嵌套的块列表中组装一个nd数组。

bool_

bool 的别名

broadcast_arrays(*args)

将数组广播到公共形状。

broadcast_shapes()

将输入形状广播到公共输出形状。

broadcast_to(array, shape)

将数组广播到指定形状。

c_

沿最后一个轴连接切片、标量和类似数组的对象。

can_cast(from_, to[, casting])

如果根据类型转换规则可以在数据类型之间进行转换,则返回 True。

cbrt(x, /)

返回数组的立方根,逐元素进行。

cdouble

complex128 的别名

ceil(x, /)

将输入向上舍入到最接近的整数。

character()

所有字符串标量类型的抽象基类。

choose(a, choices[, out, mode])

从索引数组和选择数组列表中构造一个数组。

clip([arr, min, max, a, a_min, a_max])

将数组值裁剪到指定范围。

column_stack(tup)

将一维数组堆叠为二维数组的列。

complex_

complex128 的别名

complex128(x)

complex64(x)

complexfloating()

所有由浮点数组成的复数标量类型的抽象基类。

ComplexWarning

将复杂数据类型转换为实数数据类型时引发的警告。

compress(condition, a[, axis, size, ...])

使用布尔条件沿给定轴压缩数组。

concat(arrays, /, *[, axis])

沿现有轴连接一系列数组。

concatenate(arrays[, axis, dtype])

沿现有轴连接一系列数组。

conj(x, /)

返回逐元素的复共轭。

conjugate(x, /)

返回逐元素的复共轭。

convolve(a, v[, mode, precision, ...])

两个一维数组的卷积。

copy(a[, order])

返回数组的副本。

copysign(x1, x2, /)

x2 中每个元素的符号复制到 x1 中对应的元素。

corrcoef(x[, y, rowvar])

返回皮尔逊积矩相关系数。

correlate(a, v[, mode, precision, ...])

两个一维数组的关联。

cos(x, /)

计算输入中每个元素的三角余弦值。

cosh(x, /)

双曲余弦,逐元素计算。

count_nonzero(a[, axis, keepdims])

返回沿指定轴的非零元素的数量。

cov(m[, y, rowvar, bias, ddof, fweights, ...])

给定数据和权重,估计一个协方差矩阵。

cross(a, b[, axisa, axisb, axisc, axis])

返回两个(数组)向量的叉积。

csingle

complex64 的别名

cumprod(a[, axis, dtype, out])

沿轴的元素累积乘积。

cumsum(a[, axis, dtype, out])

沿轴的元素累计和。

cumulative_sum(x, /, *[, axis, dtype, ...])

沿数组轴的累积和。

deg2rad(x, /)

将角度从度转换为弧度。

degrees(x, /)

将角度从弧度转换为度数。

delete(arr, obj[, axis, assume_unique_indices])

从数组中删除条目或多个条目。

diag(v[, k])

返回指定的对角线或构造一个对角线数组。

diag_indices(n[, ndim])

返回用于访问多维数组主对角线的索引。

diag_indices_from(arr)

返回用于访问给定数组主对角线的索引。

diagflat(v[, k])

返回一个二维数组,其中扁平化的输入数组沿对角线排列。

diagonal(a[, offset, axis1, axis2])

返回数组的指定对角线。

diff(a[, n, axis, prepend, append])

计算沿给定轴的第 n 次离散差分。

digitize(x, bins[, right])

返回输入数组中每个值所属的箱子的索引。

divide(x1, x2, /)

别名 jax.numpy.true_divide()

divmod(x1, x2, /)

计算 x1 除以 x2 的整数商和余数,逐元素进行。

dot(a, b, *[, precision, preferred_element_type])

计算两个数组的点积。

double

float64 的别名

dsplit(ary, indices_or_sections)

将数组深度分割为子数组。

dstack(tup[, dtype])

按深度顺序堆叠数组(沿第三轴)。

dtype(dtype[, align, copy])

创建一个数据类型对象。

ediff1d(ary[, to_end, to_begin])

数组中连续元素之间的差异。

einsum()

爱因斯坦求和

einsum_path()

在不评估einsum的情况下,评估最优的收缩路径。

empty(shape[, dtype, device])

创建一个空数组。

empty_like(prototype[, dtype, shape, device])

创建一个与指定数组具有相同形状和数据类型的空数组。

equal(x, y, /)

逐元素返回 (x1 == x2)。

exp(x, /)

计算输入的逐元素指数。

exp2(x, /)

计算输入的逐元素以2为底的指数。

expand_dims(a, axis)

将长度为1的维度插入数组

expm1(x, /)

计算输入的每个元素的 exp(x)-1

extract(condition, arr, *[, size, fill_value])

返回满足条件的数组元素。

eye(N[, M, k, dtype, device])

创建一个方形或矩形的单位矩阵

fabs(x, /)

计算实值输入的逐元素绝对值。

fill_diagonal(a, val[, wrap, inplace])

返回一个数组的副本,其中对角线被覆盖。

finfo(dtype)

浮点类型的机器限制。

fix(x[, out])

将输入四舍五入到最接近的整数,趋向于零。

flatnonzero(a, *[, size, fill_value])

返回展平数组中非零元素的索引

flexible()

所有无预定义长度的标量类型的抽象基类。

flip(m[, axis])

沿给定轴反转数组元素的顺序。

fliplr(m)

沿轴1反转数组元素的顺序。

flipud(m)

沿轴 0 反转数组元素的顺序。

float_

float64 的别名

float_power(x, y, /)

计算元素级别的 x 底数的 y 指数。

float16(x)

float32(x)

float64(x)

floating()

所有浮点标量类型的抽象基类。

floor(x, /)

将输入向下舍入到最接近的整数。

floor_divide(x1, x2, /)

计算 x1 与 x2 的逐元素整除

fmax(x1, x2)

返回输入数组中逐元素的最大值。

fmin(x1, x2)

返回输入数组中逐元素的最小值。

fmod(x1, x2, /)

返回逐元素的除法余数。

frexp(x, /)

将 x 的元素分解为尾数和二的指数。

frombuffer(buffer[, dtype, count, offset])

将缓冲区解释为一维数组。

fromfile(*args, **kwargs)

未实现的 JAX 包装器用于 jnp.fromfile。

fromfunction(function, shape, *[, dtype])

通过在每个坐标上执行一个函数来构造一个数组。

fromiter(*args, **kwargs)

未实现的 JAX 包装器用于 jnp.fromiter。

frompyfunc(func, /, nin, nout, *[, identity])

从任意兼容JAX的标量函数创建一个JAX ufunc。

fromstring(string[, dtype, count])

从字符串中的文本数据初始化的新 1-D 数组。

from_dlpack(x, /, *[, device, copy])

通过 DLPack 构建 JAX 数组。

full(shape, fill_value[, dtype, device])

创建一个充满指定值的数组。

full_like(a, fill_value[, dtype, shape, device])

创建一个充满指定值的数组,其形状和数据类型与另一个数组相同。

gcd(x1, x2)

计算两个数组的最大公约数。

generic()

numpy 标量类型的基类。

geomspace(start, stop[, num, endpoint, ...])

返回在对数尺度上均匀分布的数字(几何级数)。

get_printoptions()

返回当前的打印选项。

gradient(f, *varargs[, axis, edge_order])

返回一个 N 维数组的梯度。

greater(x, y, /)

返回 (x1 > x2) 的元素级真值。

greater_equal(x, y, /)

返回 (x1 >= x2) 的元素级真值。

hamming(M)

返回汉明窗。

hanning(M)

返回汉宁窗。

heaviside(x1, x2, /)

计算 Heaviside 阶跃函数。

histogram(a[, bins, range, weights, density])

计算数据集的直方图。

histogram_bin_edges(a[, bins, range, weights])

用于计算 histogram 使用的箱子边缘的函数

histogram2d(x, y[, bins, range, weights, ...])

计算两个数据样本的二维直方图。

histogramdd(sample[, bins, range, weights, ...])

计算某些数据的多维直方图。

hsplit(ary, indices_or_sections)

将数组水平分割成子数组。

hstack(tup[, dtype])

按顺序水平堆叠数组(按列)。

hypot(x1, x2, /)

给定直角三角形的“边”,返回其斜边。

i0

第一类修正贝塞尔函数,0阶。

identity(n[, dtype])

创建一个单位矩阵

iinfo(int_type)

imag(val, /)

返回复数参数的虚部。

index_exp

构建数组索引元组的更好方法。

indices()

返回一个表示网格索引的数组。

inexact()

所有数值标量类型的抽象基类,其范围内的值可能具有(潜在的)不精确表示,例如浮点数。

inner(a, b, *[, precision, ...])

计算两个数组的内积。

insert(arr, obj, values[, axis])

在给定的轴上,在给定的索引之前插入值。

int_

int64 的别名

int16(x)

int32(x)

int64(x)

int8(x)

integer()

所有整数标量类型的抽象基类。

interp(x, xp, fp[, left, right, period])

针对单调递增样本点的一维线性插值。

intersect1d(ar1, ar2[, assume_unique, ...])

计算两个一维数组的集合交集。

invert(x, /)

按元素计算位反转,或按位 NOT。

isclose(a, b[, rtol, atol, equal_nan])

检查两个数组的元素是否在容差范围内近似相等。

iscomplex(x)

返回布尔数组,显示输入是否为复数。

iscomplexobj(x)

检查输入是否为复数或包含复数元素的数组。

isdtype(dtype, kind)

返回一个布尔值,指示提供的 dtype 是否属于指定类型。

isfinite(x, /)

测试逐元素是否为有限值(不是无穷大且不是非数值)。

isin(element, test_elements[, ...])

确定 element 中的元素是否出现在 test_elements 中。

isinf(x, /)

测试元素是否为正无穷或负无穷。

isnan(x, /)

逐元素测试 NaN 并返回结果为布尔数组。

isneginf(x, /[, out])

逐元素测试正无穷大,返回结果为布尔数组。

isposinf(x, /[, out])

逐元素测试正无穷大,返回结果为布尔数组。

isreal(x)

返回一个布尔数组,显示输入是否为实数。

isrealobj(x)

检查输入是否不是复数或包含复数元素的数组。

isscalar(element)

如果 element 的类型是标量类型,则返回 True。

issubdtype(arg1, arg2)

如果第一个参数在类型层次结构中低于或等于第二个参数,则返回 True。

iterable(y)

检查一个对象是否可以被迭代。

ix_(*args)

从 N 个一维序列返回一个多维网格(开放网格)。

kaiser(M, beta)

返回凯泽窗。

kron(a, b)

计算两个输入数组的Kronecker积。

lcm(x1, x2)

计算两个数组的最小公倍数。

ldexp(x1, x2, /)

返回 x1 * 2**x2,逐元素计算。

left_shift(x, y, /)

将整数的位向左移动。

less(x, y, /)

返回 (x1 < x2) 的元素级真值。

less_equal(x, y, /)

返回 (x1 <= x2) 的元素级真值。

lexsort(keys[, axis])

使用一系列键执行间接稳定排序。

linspace()

返回区间内的等间隔数字。

load(*args, **kwargs)

.npy, .npz 或序列化文件中加载数组或序列化对象。

log(x, /)

计算输入的逐元素自然对数。

log10(x, /)

计算 x 元素的以 10 为底的对数

log1p(x, /)

计算输入元素加一的对数,log(x+1)

log2(x, /)

计算 x 的逐元素以 2 为底的对数

logaddexp(x1, x2, /)

计算 log(exp(x1) + exp(x2)) 以避免溢出。

logaddexp2

以2为底的对数,输入的指数和。

logical_and

逐元素计算逻辑与操作。

logical_not(x, /)

计算 NOT x 的元素级真值。

logical_or

计算逐元素的逻辑或运算。

logical_xor

逐元素计算逻辑异或运算。

logspace(start, stop[, num, endpoint, base, ...])

返回在对数刻度上均匀间隔的数字。

mask_indices(*args, **kwargs)

返回一个掩码函数的索引,以访问 (n, n) 数组。

matmul(a, b, *[, precision, ...])

执行矩阵乘法。

matrix_transpose(x, /)

转置数组的最后两个维度。

max(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值。

maximum(x, y, /)

返回输入数组中逐元素的最大值。

mean(a[, axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的平均值。

median(a[, axis, out, overwrite_input, keepdims])

返回沿给定轴的数组元素的中位数。

meshgrid(*xi[, copy, sparse, indexing])

从坐标向量返回坐标矩阵的元组。

mgrid

返回密集的多维“网格”。

min(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值。

minimum(x, y, /)

返回输入数组中逐元素的最小值。

mod(x1, x2, /)

返回逐元素的除法余数。

modf(x, /[, out])

返回数组中每个元素的分数部分和整数部分。

moveaxis(a, source, destination)

将数组轴移动到新位置

multiply

逐元素相乘两个数组。

nan_to_num(x[, copy, nan, posinf, neginf])

将 NaN 替换为零,将无穷大替换为大的有限数(默认)

nanargmax(a[, axis, out, keepdims])

返回指定轴上最大值的索引,忽略

nanargmin(a[, axis, out, keepdims])

返回指定轴上最小值的索引,忽略

nancumprod(a[, axis, dtype, out])

沿轴的元素累积乘积,忽略NaN值。

nancumsum(a[, axis, dtype, out])

沿轴的元素累积和,忽略NaN值。

nanmax(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值,忽略 NaNs。

nanmean(a[, axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的均值,忽略 NaNs。

nanmedian(a[, axis, out, overwrite_input, ...])

返回沿给定轴的数组元素的中位数,忽略 NaNs。

nanmin(a[, axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值,忽略 NaNs。

nanpercentile(a, q[, axis, out, ...])

计算沿指定轴的数据百分位数,忽略 NaN 值。

nanprod(a[, axis, dtype, out, keepdims, ...])

返回沿给定轴的数组元素的乘积,忽略 NaNs。

nanquantile(a, q[, axis, out, ...])

计算沿指定轴的数据分位数,忽略 NaNs。

nanstd(a[, axis, dtype, out, ddof, ...])

计算沿指定轴的标准差,忽略 NaNs。

nansum(a[, axis, dtype, out, keepdims, ...])

返回沿给定轴的数组元素之和,忽略 NaNs。

nanvar(a[, axis, dtype, out, ddof, ...])

计算沿给定轴的数组元素的方差,忽略 NaNs。

ndarray

Array 的别名

ndim(a)

返回数组的维度数。

negative(x, /)

返回输入元素的负值。

nextafter(x, y, /)

返回 xy 方向的下一个浮点值。

nonzero(a, *[, size, fill_value])

返回数组中非零元素的索引。

not_equal(x, y, /)

逐元素返回 (x1 != x2)。

number()

所有数值标量类型的抽象基类。

object_

任何 Python 对象。

ogrid

返回开放的多维“网格”。

ones(shape[, dtype, device])

创建一个充满1的数组。

ones_like(a[, dtype, shape, device])

创建一个与给定数组具有相同形状和数据类型的全一数组。

outer(a, b[, out])

计算两个数组的外积。

packbits(a[, axis, bitorder])

将二值数组的元素打包成 uint8 数组中的位。

pad(array, pad_width[, mode])

填充数组。

partition(a, kth[, axis])

返回数组的部分排序副本。

percentile(a, q[, axis, out, ...])

计算数据沿指定轴的百分位数。

permute_dims(a, /, axes)

置换数组的轴/维度。

piecewise(x, condlist, funclist, *args, **kw)

在整个定义域上分段评估一个函数。

place(arr, mask, vals, *[, inplace])

基于掩码更新数组元素。

poly(seq_of_zeros)

返回给定根序列的多项式的系数。

polyadd(a1, a2)

返回两个多项式的和。

polyder(p[, m])

返回指定阶数多项式的导数的系数。

polydiv(u, v, *[, trim_leading_zeros])

返回多项式除法的商和余数。

polyfit(x, y, deg[, rcond, full, w, cov])

最小二乘多项式拟合数据。

polyint(p[, m, k])

返回多项式的指定阶数积分的系数。

polymul(a1, a2, *[, trim_leading_zeros])

返回两个多项式的乘积。

polysub(a1, a2)

返回两个多项式的差。

polyval(p, x, *[, unroll])

在特定值处计算多项式。

positive(x, /)

返回输入元素的正值。

pow(x1, x2, /)

第一个数组的元素按第二个数组的元素逐个求幂。

power(x1, x2, /)

第一个数组的元素按第二个数组的元素逐个求幂。

printoptions(*args, **kwargs)

用于设置打印选项的上下文管理器。

prod(a[, axis, dtype, out, keepdims, ...])

返回数组元素在给定轴上的乘积。

promote_types(a, b)

返回二元运算应将其参数转换为的类型。

ptp(a[, axis, out, keepdims])

返回给定轴上的峰峰值范围。

put(a, ind, v[, mode, inplace])

将元素放入指定索引的数组中。

quantile(a, q[, axis, out, overwrite_input, ...])

计算数据沿指定轴的分位数。

r_

沿第一个轴连接切片、标量和类似数组的对象。

rad2deg(x, /)

将角度从弧度转换为度数。

radians(x, /)

将角度从度转换为弧度。

ravel(a[, order])

将数组展平为1维形状。

ravel_multi_index(multi_index, dims[, mode, ...])

将多维索引转换为扁平索引。

real(val, /)

返回复数参数的实部。

reciprocal(x, /)

返回参数的倒数,逐元素进行。

remainder(x1, x2, /)

返回逐元素的除法余数。

repeat(a, repeats[, axis, total_repeat_length])

从重复元素构建数组。

reshape(a[, shape, order, newshape, copy])

返回数组的重新形状的副本。

resize(a, new_shape)

返回一个具有指定形状的新数组。

result_type(*args)

返回应用 NumPy 后的类型

right_shift(x1, x2, /)

x1 的位向右移动 x2 指定的数量。

rint(x, /)

将 x 的元素四舍五入到最近的整数

roll(a, shift[, axis])

沿指定轴滚动数组的元素。

rollaxis(a, axis[, start])

将指定轴滚动到给定位置。

roots(p, *[, strip_zeros])

返回给定系数 p 的多项式的根。

rot90(m[, k, axes])

在由轴指定的平面内将数组逆时针旋转90度。

round(a[, decimals, out])

将输入值四舍五入到给定的位数。

round_(a[, decimals, out])

将输入值四舍五入到给定的位数。

s_

构建数组索引元组的更好方法。

save(file, arr[, allow_pickle, fix_imports])

将数组保存为 NumPy .npy 格式的二进制文件。

savez(file, *args, **kwds)

将多个数组保存到一个未压缩的 .npz 格式文件中。

searchsorted(a, v[, side, sorter, method])

在一个已排序的数组中执行二分查找。

select(condlist, choicelist[, default])

根据一系列条件选择值。

set_printoptions([precision, threshold, ...])

设置打印选项。

setdiff1d(ar1, ar2[, assume_unique, size, ...])

计算两个一维数组的集合差。

setxor1d(ar1, ar2[, assume_unique, size, ...])

计算两个数组中元素的集合异或。

shape(a)

返回数组的形状。

sign(x, /)

返回输入的逐元素符号指示。

signbit(x, /)

返回元素级 True,其中符号位已设置(小于零)。

signedinteger()

所有有符号整数标量类型的抽象基类。

sin(x, /)

计算输入中每个元素的三角正弦值。

sinc(x, /)

返回归一化的sinc函数。

single

float32 的别名

sinh(x, /)

双曲正弦,逐元素计算。

size(a[, axis])

返回沿给定轴的元素数量。

sort(a[, axis, kind, order, stable, descending])

返回数组的排序副本。

sort_complex(a)

首先按实部排序,然后按虚部排序一个复杂的数组。

split(ary, indices_or_sections[, axis])

将一个数组分割成子数组。

sqrt(x, /)

返回数组的非负平方根,逐元素进行。

square(x, /)

返回输入的逐元素平方。

squeeze(a[, axis])

从数组中移除一个或多个长度为1的轴

stack(arrays[, axis, out, dtype])

沿新轴连接数组的序列。

std(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

subtract(x, y, /)

逐元素减去参数。

sum(a[, axis, dtype, out, keepdims, ...])

数组元素在给定轴上的总和。

swapaxes(a, axis1, axis2)

交换数组的两个轴。

take(a, indices[, axis, out, mode, ...])

从数组中提取元素。

take_along_axis(arr, indices, axis[, mode, ...])

从数组中提取元素。

tan(x, /)

计算输入中每个元素的三角正切值。

tanh(x, /)

逐元素计算双曲正切。

tensordot(a, b[, axes, precision, ...])

计算两个N维数组的张量点积。

tile(A, reps)

通过重复 A 的次数来构造一个数组,次数由 reps 给出。

trace(a[, offset, axis1, axis2, dtype, out])

返回数组对角线上的元素之和。

trapezoid(y[, x, dx, axis])

使用复合梯形法则沿给定轴进行积分。

transpose(a[, axes])

返回一个 N 维数组的转置版本。

tri(N[, M, k, dtype])

返回一个数组,其中对角线及其下方为1,其他位置为0。

tril(m[, k])

返回数组的下三角部分。

tril_indices(n[, k, m])

返回大小为 (n, m) 的数组的下三角索引。

tril_indices_from(arr[, k])

返回给定数组的下三角索引。

trim_zeros(filt[, trim])

修剪输入数组的前导和/或尾随零。

triu(m[, k])

返回数组的上三角部分。

triu_indices(n[, k, m])

返回大小为 (n, m) 的数组的上三角索引。

triu_indices_from(arr[, k])

返回给定数组的上三角索引。

true_divide(x1, x2, /)

计算 x1 与 x2 的逐元素除法

trunc(x)

将输入四舍五入到最接近的整数,趋向于零。

ufunc(func, /, nin, nout, *[, name, nargs, ...])

对数组进行逐元素操作的通用函数。

uint

uint64 的别名

uint16(x)

uint32(x)

uint64(x)

uint8(x)

union1d(ar1, ar2, *[, size, fill_value])

计算两个一维数组的并集。

unique(ar[, return_index, return_inverse, ...])

返回数组中的唯一值。

unique_all(x, /, *[, size, fill_value])

从 x 中返回唯一值,以及索引、逆索引和计数。

unique_counts(x, /, *[, size, fill_value])

从 x 中返回唯一值及其计数。

unique_inverse(x, /, *[, size, fill_value])

从 x 中返回唯一值,以及索引、逆索引和计数。

unique_values(x, /, *[, size, fill_value])

从 x 中返回唯一值,以及索引、逆索引和计数。

unpackbits(a[, axis, count, bitorder])

将 uint8 数组的元素解包到二进制值的输出数组中。

unravel_index(indices, shape)

将平面索引转换为多维索引。

unstack(x, /, *[, axis])

沿着给定轴将数组分割成一系列数组。

unsignedinteger()

所有无符号整数标量类型的抽象基类。

unwrap(p[, discont, axis, period])

通过取大增量相对于周期的补码来进行解包。

vander(x[, N, increasing])

生成一个范德蒙矩阵。

var(a[, axis, dtype, out, ddof, keepdims, ...])

计算沿指定轴的方差。

vdot(a, b, *[, precision, ...])

对两个一维向量执行共轭乘法。

vecdot(x1, x2, /, *[, axis, precision, ...])

执行两个批量向量的共轭乘法。

vectorize(pyfunc, *[, excluded, signature])

定义一个带有广播功能的矢量化函数。

vsplit(ary, indices_or_sections)

将数组垂直分割成子数组。

vstack(tup[, dtype])

按顺序垂直堆叠数组(按行)。

where()

根据条件从两个数组中选择元素。

zeros(shape[, dtype, device])

创建一个充满零的数组。

zeros_like(a[, dtype, shape, device])

创建一个充满零的数组,其形状和数据类型与给定数组相同。

jax.numpy.fft#

fft(a[, n, axis, norm])

沿给定轴计算一维离散傅里叶变换。

fft2(a[, s, axes, norm])

沿着给定的轴计算二维离散傅里叶变换。

fftfreq(n[, d, dtype, device])

返回离散傅里叶变换的样本频率。

fftn(a[, s, axes, norm])

沿给定轴计算多维离散傅里叶变换。

fftshift(x[, axes])

将零频分量移到频谱中心。

hfft(a[, n, axis, norm])

计算具有厄米特对称性的数组的 1-D FFT。

ifft(a[, n, axis, norm])

计算一维逆离散傅里叶变换。

ifft2(a[, s, axes, norm])

计算二维逆离散傅里叶变换。

ifftn(a[, s, axes, norm])

计算多维逆离散傅里叶变换。

ifftshift(x[, axes])

fftshift 的逆操作。

ihfft(a[, n, axis, norm])

计算具有厄米特对称性的数组的 1-D 逆 FFT。

irfft(a[, n, axis, norm])

计算一个实数值的一维离散傅里叶逆变换。

irfft2(a[, s, axes, norm])

计算一个实数值的二维离散傅里叶逆变换。

irfftn(a[, s, axes, norm])

计算实数值多维逆离散傅里叶变换。

rfft(a[, n, axis, norm])

计算实值数组的一维离散傅里叶变换。

rfft2(a[, s, axes, norm])

计算实值数组的二维离散傅里叶变换。

rfftfreq(n[, d, dtype, device])

返回离散傅里叶变换的样本频率。

rfftn(a[, s, axes, norm])

计算实值数组的多维离散傅里叶变换。

jax.numpy.linalg#

cholesky(a, *[, upper])

计算矩阵的 Cholesky 分解。

cond(x[, p])

计算矩阵的条件数。

cross(x1, x2, /, *[, axis])

计算两个三维向量的叉积

det(a)

计算数组的行列式。

diagonal(x, /, *[, offset])

提取矩阵或矩阵堆的对角线。

eig(a)

计算方阵的特征值和特征向量。

eigh(a[, UPLO, symmetrize_input])

计算厄米矩阵的特征值和特征向量。

eigvals(a)

计算一般矩阵的特征值。

eigvalsh(a[, UPLO])

计算厄米矩阵的特征值。

inv(a)

返回一个方阵的逆矩阵

lstsq(a, b[, rcond, numpy_resid])

返回线性方程的最小二乘解。

matmul(x1, x2, /, *[, precision, ...])

执行矩阵乘法。

matrix_norm(x, /, *[, keepdims, ord])

计算矩阵或矩阵堆栈的范数。

matrix_power(a, n)

将一个方阵提升到一个整数幂。

matrix_rank(M[, rtol, tol])

计算矩阵的秩。

matrix_transpose(x, /)

转置矩阵或矩阵堆栈。

multi_dot(arrays, *[, precision])

高效计算数组序列之间的矩阵乘积。

norm(x[, ord, axis, keepdims])

计算矩阵或向量的范数。

outer(x1, x2, /)

计算两个一维数组的外积。

pinv(a[, rtol, hermitian, rcond])

计算矩阵的 (Moore-Penrose) 伪逆。

qr()

计算数组的QR分解

slogdet(a, *[, method])

计算数组的符号和(自然)对数行列式。

solve(a, b)

求解线性方程组

svd()

计算奇异值分解。

svdvals(x, /)

计算矩阵的奇异值。

tensordot(x1, x2, /, *[, axes, precision, ...])

计算两个N维数组的张量点积。

tensorinv(a[, ind])

计算数组的张量逆。

tensorsolve(a, b[, axes])

求解张量方程 a x = b 中的 x。

trace(x, /, *[, offset, dtype])

计算矩阵的迹。

vector_norm(x, /, *[, axis, keepdims, ord])

计算向量或一批向量的向量范数。

vecdot(x1, x2, /, *[, axis, precision, ...])

计算两个数组的(批量)向量共轭点积。

JAX 数组#

JAX 的 ndarray)是 JAX 中的核心数组对象:你可以将其视为 JAX 中与 numpy.ndarray 等效的对象。与 numpy.ndarray 类似,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数(如 array()arange()linspace() 等)来创建它们。

复制与序列化#

JAX Array 对象旨在在适当的情况下与 Python 标准库工具无缝协作。

使用内置的 copy 模块,当 copy.copy()copy.deepcopy() 遇到 Array 时,它等同于调用 copy() 方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这在跟踪/JIT 编译的代码中将正确工作,尽管在此上下文中编译器可能会省略复制操作。

当内置的 pickle 模块遇到一个 Array 时,它将通过一种紧凑的位表示形式进行序列化,类似于被pickle的 numpy.ndarray 对象。当反序列化时,结果将是一个新的 Array 对象 在默认设备上。这是因为通常情况下,序列化和反序列化可能发生在不同的运行时环境中,并且没有一种通用的方法可以将一个运行时的设备ID映射到另一个运行时的设备ID。如果在追踪/JIT编译的代码中使用 pickle,将会导致 ConcretizationTypeError

Python 数组 API 标准#

备注

在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api 以启用 JAX 数组的数组 API。在 JAX v0.4.32 之后,导入此模块不再需要,并且会引发弃用警告。

从 JAX v0.4.32 开始,jax.Arrayjax.numpyPython 数组 API 标准 兼容。您可以通过 jax.Array.__array_namespace__() 访问数组 API 命名空间:

>>> def f(x):
...   nx = x.__array_namespace__()
...   return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX 在某些方面与标准有所不同,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。

更多信息,请参阅 Python 数组 API 标准 文档。