jax.Array

jax.Array#

class jax.Array#

JAX 的数组基类

jax.Array 是用于 JAX 数组和追踪器实例检查和类型注解的公共接口。它的主要应用在于实例检查和类型注解;例如:

x = jnp.arange(5)
isinstance(x, jax.Array)  # returns True both inside and outside traced functions.

def f(x: Array) -> Array:  # type annotations are valid for traced and non-traced types.
  return x

jax.Array 不应直接用于数组的创建;相反,您应该使用 jax.numpy 中提供的数组创建例程,例如 jax.numpy.array()jax.numpy.zeros()jax.numpy.ones()jax.numpy.full()jax.numpy.arange() 等。

__init__()#

方法

__init__()

addressable_data(index)

返回特定索引处可寻址数据的数组。

all([axis, out, keepdims, where])

测试沿给定轴的所有数组元素是否评估为 True。

any([axis, out, keepdims, where])

测试沿给定轴的任何数组元素是否评估为 True。

argmax([axis, out, keepdims])

返回最大值的索引。

argmin([axis, out, keepdims])

返回最小值的索引。

argpartition(kth[, axis])

返回部分排序数组的索引。

argsort([axis, kind, order, stable, descending])

返回对数组进行排序的索引。

astype(dtype[, copy, device])

复制数组并转换为指定的数据类型。

choose(choices[, out, mode])

从多个数组的元素中构建一个数组。

clip([min, max])

返回一个数组,其值被限制在指定范围内。

compress(condition[, axis, out, size, ...])

返回沿给定轴的此数组的选择切片。

conj()

返回数组的复共轭。

conjugate()

返回数组的复共轭。

copy()

返回数组的副本。

copy_to_host_async()

异步地将 Array 复制到主机。

cumprod([axis, dtype, out])

返回数组的累积乘积。

cumsum([axis, dtype, out])

返回数组的累计和。

diagonal([offset, axis1, axis2])

返回数组中指定的对角线。

dot(b, *[, precision, preferred_element_type])

计算两个数组的点积。

flatten([order])

将数组展平为1维形状。

item(*args)

将数组的一个元素复制到一个标准的 Python 标量并返回它。

max([axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最大值。

mean([axis, dtype, out, keepdims, where])

返回沿给定轴的数组元素的平均值。

min([axis, out, keepdims, initial, where])

返回沿给定轴的数组元素的最小值。

nonzero(*[, fill_value, size])

返回数组中非零元素的索引。

prod([axis, dtype, out, keepdims, initial, ...])

返回数组元素在给定轴上的乘积。

ptp([axis, out, keepdims])

返回给定轴上的峰峰值范围。

ravel([order])

将数组展平为1维形状。

repeat(repeats[, axis, total_repeat_length])

从重复元素构建数组。

reshape(*args[, order])

返回一个包含相同数据但形状不同的新数组。

round([decimals, out])

将数组元素四舍五入到给定的小数位。

searchsorted(v[, side, sorter, method])

在一个已排序的数组中执行二分查找。

sort([axis, kind, order, stable, descending])

返回数组的排序副本。

squeeze([axis])

从数组中移除一个或多个长度为1的轴。

std([axis, dtype, out, ddof, keepdims, ...])

计算沿给定轴的标准差。

sum([axis, dtype, out, keepdims, initial, ...])

数组元素在给定轴上的总和。

swapaxes(axis1, axis2)

交换数组的两个轴。

take(indices[, axis, out, mode, ...])

从数组中提取元素。

to_device(device, *[, stream])

返回指定设备上的数组副本

trace([offset, axis1, axis2, dtype, out])

返回对角线上的元素之和。

transpose(*args)

返回数组的一个转置副本。

var([axis, dtype, out, ddof, keepdims, ...])

计算沿指定轴的方差。

view([dtype, type])

返回数组的按位复制,视作新的数据类型。

属性

T

计算全轴数组转置。

addressable_shards

可寻址分片列表。

at

用于索引更新功能的辅助属性。

device

与数组API兼容的设备属性。

dtype

数组的 数据类型 (numpy.dtype)。

flat

使用 flatten() 代替。

global_shards

全局分片列表。

imag

返回数组的虚部。

is_fully_addressable

这个数组是完全可寻址的吗?

is_fully_replicated

这个数组是完全复制的吗?

itemsize

一个数组元素的字节长度。

mT

计算(批量)矩阵转置。

nbytes

数组元素消耗的总字节数。

ndim

数组的维度数量。

real

返回数组的实部。

shape

数组的形状。

sharding

数组的切片。

size

数组中元素的总数。