jax.numpy
模块#
使用 jax.lax
中的原语实现了 NumPy API。
虽然 JAX 尽可能地遵循 NumPy API,但有时 JAX 无法完全遵循 NumPy。
值得注意的是,由于 JAX 数组是不可变的,NumPy API 中那些就地修改数组的函数无法在 JAX 中实现。然而,JAX 通常能够提供一个纯函数式的替代 API。例如,JAX 提供了替代的就地数组更新函数
x.at[i].set(y)
(参见ndarray.at
),而不是就地数组更新 (x[i] = y
)。同样地,一些 NumPy 函数在可能的情况下通常返回数组的视图(例如
transpose()
和reshape()
)。JAX 版本的此类函数将返回副本,尽管在使用jax.jit()
编译操作序列时,XLA 通常会优化掉这些副本。NumPy 在将值提升为
float64
类型时非常激进。JAX 在类型提升方面有时不那么激进(参见 类型提升)。一些 NumPy 例程的输出形状依赖于数据(例如
unique()
和nonzero()
)。由于 XLA 编译器要求在编译时知道数组形状,因此这些操作与 JIT 不兼容。为此,JAX 为这些函数添加了一个可选的size
参数,可以在使用 JIT 时静态指定该参数。
几乎所有适用的 NumPy 函数都在 jax.numpy
命名空间中实现;它们列在下面。
用于索引更新功能的辅助属性。 |
|
|
|
|
逐元素计算绝对值。 |
|
|
|
|
逐元素相加两个数组。 |
|
|
测试沿给定轴的所有数组元素是否评估为 True。 |
|
检查两个数组在容差范围内是否逐元素近似相等。 |
|
别名 |
|
别名 |
|
返回复数值数或数组的角。 |
|
测试沿给定轴的数组元素是否评估为 True。 |
|
返回一个新数组,该数组在原数组的末尾附加了值。 |
|
沿给定轴对1-D切片应用函数。 |
|
在多个轴上重复应用函数。 |
|
创建一个等间距值的数组。 |
|
逐元素计算三角反余弦。 |
|
逐元素计算反双曲余弦。 |
|
逐元素计算反正弦。 |
|
逐元素计算反双曲正弦。 |
|
三角反切函数,逐元素计算。 |
|
逐元素计算 |
|
逐元素计算反双曲正切。 |
|
返回沿某个轴的最大值的索引。 |
|
返回沿某个轴的最小值的索引。 |
|
返回部分排序数组的索引。 |
|
返回排序数组的索引。 |
|
查找非零数组元素的索引 |
|
|
|
将对象转换为 JAX 数组。 |
|
检查两个数组是否逐元素相等。 |
|
检查两个数组是否逐元素相等。 |
|
返回数组的字符串表示形式。 |
|
将一个数组分割成子数组。 |
|
返回数组中数据的字符串表示形式。 |
|
将对象转换为 JAX 数组。 |
|
|
|
|
|
将数组转换为指定的数据类型。 |
|
|
|
|
|
|
将输入转换为至少具有一个维度的数组。 |
|
将输入视为至少具有两个维度的数组。 |
|
将输入视为至少具有三个维度的数组。 |
|
|
计算沿指定轴的加权平均值。 |
|
返回 Bartlett 窗口。 |
|
计算整数数组中每个值的出现次数。 |
按元素计算按位与操作。 |
|
|
计算 |
|
按元素计算位反转,或按位 NOT。 |
|
将整数的位向左移动。 |
|
按元素计算位反转,或按位 NOT。 |
按元素计算按位或运算。 |
|
|
|
逐元素计算按位异或运算。 |
|
|
返回布莱克曼窗。 |
|
从嵌套的块列表中组装一个nd数组。 |
|
|
|
将数组广播到公共形状。 |
将输入形状广播到公共输出形状。 |
|
|
将数组广播到指定形状。 |
沿最后一个轴连接切片、标量和类似数组的对象。 |
|
|
如果根据类型转换规则可以在数据类型之间进行转换,则返回 True。 |
|
返回数组的立方根,逐元素进行。 |
|
|
|
将输入向上舍入到最接近的整数。 |
所有字符串标量类型的抽象基类。 |
|
|
从索引数组和选择数组列表中构造一个数组。 |
|
将数组值裁剪到指定范围。 |
|
将一维数组堆叠为二维数组的列。 |
|
|
|
|
|
|
所有由浮点数组成的复数标量类型的抽象基类。 |
|
将复杂数据类型转换为实数数据类型时引发的警告。 |
|
|
使用布尔条件沿给定轴压缩数组。 |
|
沿现有轴连接一系列数组。 |
|
沿现有轴连接一系列数组。 |
|
返回逐元素的复共轭。 |
|
返回逐元素的复共轭。 |
|
两个一维数组的卷积。 |
|
返回数组的副本。 |
|
将 |
|
返回皮尔逊积矩相关系数。 |
|
两个一维数组的关联。 |
|
计算输入中每个元素的三角余弦值。 |
|
双曲余弦,逐元素计算。 |
|
返回沿指定轴的非零元素的数量。 |
|
给定数据和权重,估计一个协方差矩阵。 |
|
返回两个(数组)向量的叉积。 |
|
|
|
沿轴的元素累积乘积。 |
|
沿轴的元素累计和。 |
|
沿数组轴的累积和。 |
|
将角度从度转换为弧度。 |
|
将角度从弧度转换为度数。 |
|
从数组中删除条目或多个条目。 |
|
返回指定的对角线或构造一个对角线数组。 |
|
返回用于访问多维数组主对角线的索引。 |
|
返回用于访问给定数组主对角线的索引。 |
|
返回一个二维数组,其中扁平化的输入数组沿对角线排列。 |
|
返回数组的指定对角线。 |
|
计算沿给定轴的第 n 次离散差分。 |
|
返回输入数组中每个值所属的箱子的索引。 |
|
|
|
计算 x1 除以 x2 的整数商和余数,逐元素进行。 |
|
计算两个数组的点积。 |
|
|
|
将数组深度分割为子数组。 |
|
按深度顺序堆叠数组(沿第三轴)。 |
|
创建一个数据类型对象。 |
|
数组中连续元素之间的差异。 |
|
爱因斯坦求和 |
在不评估einsum的情况下,评估最优的收缩路径。 |
|
|
创建一个空数组。 |
|
创建一个与指定数组具有相同形状和数据类型的空数组。 |
|
逐元素返回 (x1 == x2)。 |
|
计算输入的逐元素指数。 |
|
计算输入的逐元素以2为底的指数。 |
|
将长度为1的维度插入数组 |
|
计算输入的每个元素的 |
|
返回满足条件的数组元素。 |
|
创建一个方形或矩形的单位矩阵 |
|
计算实值输入的逐元素绝对值。 |
|
返回一个数组的副本,其中对角线被覆盖。 |
|
浮点类型的机器限制。 |
|
将输入四舍五入到最接近的整数,趋向于零。 |
|
返回展平数组中非零元素的索引 |
|
所有无预定义长度的标量类型的抽象基类。 |
|
沿给定轴反转数组元素的顺序。 |
|
沿轴1反转数组元素的顺序。 |
|
沿轴 0 反转数组元素的顺序。 |
|
|
|
计算元素级别的 |
|
|
|
|
|
|
|
所有浮点标量类型的抽象基类。 |
|
将输入向下舍入到最接近的整数。 |
|
计算 x1 与 x2 的逐元素整除 |
|
返回输入数组中逐元素的最大值。 |
|
返回输入数组中逐元素的最小值。 |
|
返回逐元素的除法余数。 |
|
将 x 的元素分解为尾数和二的指数。 |
|
将缓冲区解释为一维数组。 |
|
未实现的 JAX 包装器用于 jnp.fromfile。 |
|
通过在每个坐标上执行一个函数来构造一个数组。 |
|
未实现的 JAX 包装器用于 jnp.fromiter。 |
|
从任意兼容JAX的标量函数创建一个JAX ufunc。 |
|
从字符串中的文本数据初始化的新 1-D 数组。 |
|
通过 DLPack 构建 JAX 数组。 |
|
创建一个充满指定值的数组。 |
|
创建一个充满指定值的数组,其形状和数据类型与另一个数组相同。 |
|
计算两个数组的最大公约数。 |
|
numpy 标量类型的基类。 |
|
返回在对数尺度上均匀分布的数字(几何级数)。 |
返回当前的打印选项。 |
|
|
返回一个 N 维数组的梯度。 |
|
返回 (x1 > x2) 的元素级真值。 |
|
返回 (x1 >= x2) 的元素级真值。 |
|
返回汉明窗。 |
|
返回汉宁窗。 |
|
计算 Heaviside 阶跃函数。 |
|
计算数据集的直方图。 |
|
用于计算 histogram 使用的箱子边缘的函数 |
|
计算两个数据样本的二维直方图。 |
|
计算某些数据的多维直方图。 |
|
将数组水平分割成子数组。 |
|
按顺序水平堆叠数组(按列)。 |
|
给定直角三角形的“边”,返回其斜边。 |
第一类修正贝塞尔函数,0阶。 |
|
|
创建一个单位矩阵 |
|
|
|
返回复数参数的虚部。 |
构建数组索引元组的更好方法。 |
|
|
返回一个表示网格索引的数组。 |
|
所有数值标量类型的抽象基类,其范围内的值可能具有(潜在的)不精确表示,例如浮点数。 |
|
计算两个数组的内积。 |
|
在给定的轴上,在给定的索引之前插入值。 |
|
|
|
|
|
|
|
|
|
|
|
所有整数标量类型的抽象基类。 |
|
针对单调递增样本点的一维线性插值。 |
|
计算两个一维数组的集合交集。 |
|
按元素计算位反转,或按位 NOT。 |
|
检查两个数组的元素是否在容差范围内近似相等。 |
|
返回布尔数组,显示输入是否为复数。 |
|
检查输入是否为复数或包含复数元素的数组。 |
|
返回一个布尔值,指示提供的 dtype 是否属于指定类型。 |
|
测试逐元素是否为有限值(不是无穷大且不是非数值)。 |
|
确定 |
|
测试元素是否为正无穷或负无穷。 |
|
逐元素测试 NaN 并返回结果为布尔数组。 |
|
逐元素测试正无穷大,返回结果为布尔数组。 |
|
逐元素测试正无穷大,返回结果为布尔数组。 |
|
返回一个布尔数组,显示输入是否为实数。 |
|
检查输入是否不是复数或包含复数元素的数组。 |
|
如果 element 的类型是标量类型,则返回 True。 |
|
如果第一个参数在类型层次结构中低于或等于第二个参数,则返回 True。 |
|
检查一个对象是否可以被迭代。 |
|
从 N 个一维序列返回一个多维网格(开放网格)。 |
|
返回凯泽窗。 |
|
计算两个输入数组的Kronecker积。 |
|
计算两个数组的最小公倍数。 |
|
返回 x1 * 2**x2,逐元素计算。 |
|
将整数的位向左移动。 |
|
返回 (x1 < x2) 的元素级真值。 |
|
返回 (x1 <= x2) 的元素级真值。 |
|
使用一系列键执行间接稳定排序。 |
|
返回区间内的等间隔数字。 |
|
从 |
|
计算输入的逐元素自然对数。 |
|
计算 x 元素的以 10 为底的对数 |
|
计算输入元素加一的对数, |
|
计算 x 的逐元素以 2 为底的对数 |
|
计算 |
以2为底的对数,输入的指数和。 |
|
逐元素计算逻辑与操作。 |
|
|
计算 NOT x 的元素级真值。 |
计算逐元素的逻辑或运算。 |
|
逐元素计算逻辑异或运算。 |
|
|
返回在对数刻度上均匀间隔的数字。 |
|
返回一个掩码函数的索引,以访问 (n, n) 数组。 |
|
执行矩阵乘法。 |
|
转置数组的最后两个维度。 |
|
返回沿给定轴的数组元素的最大值。 |
|
返回输入数组中逐元素的最大值。 |
|
返回沿给定轴的数组元素的平均值。 |
|
返回沿给定轴的数组元素的中位数。 |
|
从坐标向量返回坐标矩阵的元组。 |
返回密集的多维“网格”。 |
|
|
返回沿给定轴的数组元素的最小值。 |
|
返回输入数组中逐元素的最小值。 |
|
返回逐元素的除法余数。 |
|
返回数组中每个元素的分数部分和整数部分。 |
|
将数组轴移动到新位置 |
逐元素相乘两个数组。 |
|
|
将 NaN 替换为零,将无穷大替换为大的有限数(默认) |
|
返回指定轴上最大值的索引,忽略 |
|
返回指定轴上最小值的索引,忽略 |
|
沿轴的元素累积乘积,忽略NaN值。 |
|
沿轴的元素累积和,忽略NaN值。 |
|
返回沿给定轴的数组元素的最大值,忽略 NaNs。 |
|
返回沿给定轴的数组元素的均值,忽略 NaNs。 |
|
返回沿给定轴的数组元素的中位数,忽略 NaNs。 |
|
返回沿给定轴的数组元素的最小值,忽略 NaNs。 |
|
计算沿指定轴的数据百分位数,忽略 NaN 值。 |
|
返回沿给定轴的数组元素的乘积,忽略 NaNs。 |
|
计算沿指定轴的数据分位数,忽略 NaNs。 |
|
计算沿指定轴的标准差,忽略 NaNs。 |
|
返回沿给定轴的数组元素之和,忽略 NaNs。 |
|
计算沿给定轴的数组元素的方差,忽略 NaNs。 |
|
|
|
返回数组的维度数。 |
|
返回输入元素的负值。 |
|
返回 |
|
返回数组中非零元素的索引。 |
|
逐元素返回 (x1 != x2)。 |
|
所有数值标量类型的抽象基类。 |
任何 Python 对象。 |
|
返回开放的多维“网格”。 |
|
|
创建一个充满1的数组。 |
|
创建一个与给定数组具有相同形状和数据类型的全一数组。 |
|
计算两个数组的外积。 |
|
将二值数组的元素打包成 uint8 数组中的位。 |
|
填充数组。 |
|
返回数组的部分排序副本。 |
|
计算数据沿指定轴的百分位数。 |
|
置换数组的轴/维度。 |
|
在整个定义域上分段评估一个函数。 |
|
基于掩码更新数组元素。 |
|
返回给定根序列的多项式的系数。 |
|
返回两个多项式的和。 |
|
返回指定阶数多项式的导数的系数。 |
|
返回多项式除法的商和余数。 |
|
最小二乘多项式拟合数据。 |
|
返回多项式的指定阶数积分的系数。 |
|
返回两个多项式的乘积。 |
|
返回两个多项式的差。 |
|
在特定值处计算多项式。 |
|
返回输入元素的正值。 |
|
第一个数组的元素按第二个数组的元素逐个求幂。 |
|
第一个数组的元素按第二个数组的元素逐个求幂。 |
|
用于设置打印选项的上下文管理器。 |
|
返回数组元素在给定轴上的乘积。 |
|
返回二元运算应将其参数转换为的类型。 |
|
返回给定轴上的峰峰值范围。 |
|
将元素放入指定索引的数组中。 |
|
计算数据沿指定轴的分位数。 |
沿第一个轴连接切片、标量和类似数组的对象。 |
|
|
将角度从弧度转换为度数。 |
|
将角度从度转换为弧度。 |
|
将数组展平为1维形状。 |
|
将多维索引转换为扁平索引。 |
|
返回复数参数的实部。 |
|
返回参数的倒数,逐元素进行。 |
|
返回逐元素的除法余数。 |
|
从重复元素构建数组。 |
|
返回数组的重新形状的副本。 |
|
返回一个具有指定形状的新数组。 |
|
返回应用 NumPy 后的类型 |
|
将 |
|
将 x 的元素四舍五入到最近的整数 |
|
沿指定轴滚动数组的元素。 |
|
将指定轴滚动到给定位置。 |
|
返回给定系数 |
|
在由轴指定的平面内将数组逆时针旋转90度。 |
|
将输入值四舍五入到给定的位数。 |
|
将输入值四舍五入到给定的位数。 |
构建数组索引元组的更好方法。 |
|
|
将数组保存为 NumPy |
|
将多个数组保存到一个未压缩的 |
|
在一个已排序的数组中执行二分查找。 |
|
根据一系列条件选择值。 |
|
设置打印选项。 |
|
计算两个一维数组的集合差。 |
|
计算两个数组中元素的集合异或。 |
|
返回数组的形状。 |
|
返回输入的逐元素符号指示。 |
|
返回元素级 True,其中符号位已设置(小于零)。 |
所有有符号整数标量类型的抽象基类。 |
|
|
计算输入中每个元素的三角正弦值。 |
|
返回归一化的sinc函数。 |
|
|
|
双曲正弦,逐元素计算。 |
|
返回沿给定轴的元素数量。 |
|
返回数组的排序副本。 |
|
首先按实部排序,然后按虚部排序一个复杂的数组。 |
|
将一个数组分割成子数组。 |
|
返回数组的非负平方根,逐元素进行。 |
|
返回输入的逐元素平方。 |
|
从数组中移除一个或多个长度为1的轴 |
|
沿新轴连接数组的序列。 |
|
计算沿给定轴的标准差。 |
|
逐元素减去参数。 |
|
数组元素在给定轴上的总和。 |
|
交换数组的两个轴。 |
|
从数组中提取元素。 |
|
从数组中提取元素。 |
|
计算输入中每个元素的三角正切值。 |
|
逐元素计算双曲正切。 |
|
计算两个N维数组的张量点积。 |
|
通过重复 A 的次数来构造一个数组,次数由 reps 给出。 |
|
返回数组对角线上的元素之和。 |
|
使用复合梯形法则沿给定轴进行积分。 |
|
返回一个 N 维数组的转置版本。 |
|
返回一个数组,其中对角线及其下方为1,其他位置为0。 |
|
返回数组的下三角部分。 |
|
返回大小为 |
|
返回给定数组的下三角索引。 |
|
修剪输入数组的前导和/或尾随零。 |
|
返回数组的上三角部分。 |
|
返回大小为 |
|
返回给定数组的上三角索引。 |
|
计算 x1 与 x2 的逐元素除法 |
|
将输入四舍五入到最接近的整数,趋向于零。 |
|
对数组进行逐元素操作的通用函数。 |
|
|
|
|
|
|
|
|
|
|
|
计算两个一维数组的并集。 |
|
返回数组中的唯一值。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
从 x 中返回唯一值及其计数。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
从 x 中返回唯一值,以及索引、逆索引和计数。 |
|
将 uint8 数组的元素解包到二进制值的输出数组中。 |
|
将平面索引转换为多维索引。 |
|
沿着给定轴将数组分割成一系列数组。 |
所有无符号整数标量类型的抽象基类。 |
|
|
通过取大增量相对于周期的补码来进行解包。 |
|
生成一个范德蒙矩阵。 |
|
计算沿指定轴的方差。 |
|
对两个一维向量执行共轭乘法。 |
|
执行两个批量向量的共轭乘法。 |
|
定义一个带有广播功能的矢量化函数。 |
|
将数组垂直分割成子数组。 |
|
按顺序垂直堆叠数组(按行)。 |
|
根据条件从两个数组中选择元素。 |
|
创建一个充满零的数组。 |
|
创建一个充满零的数组,其形状和数据类型与给定数组相同。 |
jax.numpy.fft#
|
沿给定轴计算一维离散傅里叶变换。 |
|
沿着给定的轴计算二维离散傅里叶变换。 |
|
返回离散傅里叶变换的样本频率。 |
|
沿给定轴计算多维离散傅里叶变换。 |
|
将零频分量移到频谱中心。 |
|
计算具有厄米特对称性的数组的 1-D FFT。 |
|
计算一维逆离散傅里叶变换。 |
|
计算二维逆离散傅里叶变换。 |
|
计算多维逆离散傅里叶变换。 |
|
fftshift 的逆操作。 |
|
计算具有厄米特对称性的数组的 1-D 逆 FFT。 |
|
计算一个实数值的一维离散傅里叶逆变换。 |
|
计算一个实数值的二维离散傅里叶逆变换。 |
|
计算实数值多维逆离散傅里叶变换。 |
|
计算实值数组的一维离散傅里叶变换。 |
|
计算实值数组的二维离散傅里叶变换。 |
|
返回离散傅里叶变换的样本频率。 |
|
计算实值数组的多维离散傅里叶变换。 |
jax.numpy.linalg#
|
计算矩阵的 Cholesky 分解。 |
|
计算矩阵的条件数。 |
|
计算两个三维向量的叉积 |
|
计算数组的行列式。 |
|
提取矩阵或矩阵堆的对角线。 |
|
计算方阵的特征值和特征向量。 |
|
计算厄米矩阵的特征值和特征向量。 |
|
计算一般矩阵的特征值。 |
|
计算厄米矩阵的特征值。 |
|
返回一个方阵的逆矩阵 |
|
返回线性方程的最小二乘解。 |
|
执行矩阵乘法。 |
|
计算矩阵或矩阵堆栈的范数。 |
|
将一个方阵提升到一个整数幂。 |
|
计算矩阵的秩。 |
|
转置矩阵或矩阵堆栈。 |
|
高效计算数组序列之间的矩阵乘积。 |
|
计算矩阵或向量的范数。 |
|
计算两个一维数组的外积。 |
|
计算矩阵的 (Moore-Penrose) 伪逆。 |
|
计算数组的QR分解 |
|
计算数组的符号和(自然)对数行列式。 |
|
求解线性方程组 |
|
计算奇异值分解。 |
|
计算矩阵的奇异值。 |
|
计算两个N维数组的张量点积。 |
|
计算数组的张量逆。 |
|
求解张量方程 a x = b 中的 x。 |
|
计算矩阵的迹。 |
|
计算向量或一批向量的向量范数。 |
|
计算两个数组的(批量)向量共轭点积。 |
JAX 数组#
JAX 的 ndarray
)是 JAX 中的核心数组对象:你可以将其视为 JAX 中与 numpy.ndarray
等效的对象。与 numpy.ndarray
类似,大多数用户不需要手动实例化 Array
对象,而是通过 jax.numpy
函数(如 array()
、arange()
、linspace()
等)来创建它们。
复制与序列化#
JAX Array
对象旨在在适当的情况下与 Python 标准库工具无缝协作。
使用内置的 copy
模块,当 copy.copy()
或 copy.deepcopy()
遇到 Array
时,它等同于调用 copy()
方法,该方法将在与原始数组相同的设备上创建缓冲区的副本。这在跟踪/JIT 编译的代码中将正确工作,尽管在此上下文中编译器可能会省略复制操作。
当内置的 pickle
模块遇到一个 Array
时,它将通过一种紧凑的位表示形式进行序列化,类似于被pickle的 numpy.ndarray
对象。当反序列化时,结果将是一个新的 Array
对象 在默认设备上。这是因为通常情况下,序列化和反序列化可能发生在不同的运行时环境中,并且没有一种通用的方法可以将一个运行时的设备ID映射到另一个运行时的设备ID。如果在追踪/JIT编译的代码中使用 pickle
,将会导致 ConcretizationTypeError
。
Python 数组 API 标准#
备注
在 JAX v0.4.32 之前,您必须 import jax.experimental.array_api
以启用 JAX 数组的数组 API。在 JAX v0.4.32 之后,导入此模块不再需要,并且会引发弃用警告。
从 JAX v0.4.32 开始,jax.Array
和 jax.numpy
与 Python 数组 API 标准 兼容。您可以通过 jax.Array.__array_namespace__()
访问数组 API 命名空间:
>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2
>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)
JAX 在某些方面与标准有所不同,主要是因为 JAX 数组是不可变的,不支持就地更新。其中一些不兼容性正在通过 array-api-compat 模块解决。
更多信息,请参阅 Python 数组 API 标准 文档。