jax.numpy.trace#
- jax.numpy.trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None)[源代码][源代码]#
返回数组对角线上的元素之和。
LAX-backend 实现的
numpy.trace()
。原始文档字符串如下。
如果 a 是二维的,则返回其对角线上的和,偏移量为给定的值,即所有元素
a[i,i+offset]
的和。如果 a 的维度超过两个,那么由 axis1 和 axis2 指定的轴将用于确定其迹将被返回的 2-D 子数组。结果数组的形状与 a 相同,只是去掉了 axis1 和 axis2。
- 参数:
a (array_like) – 输入数组,从中提取对角线。
offset (int, optional) – 对角线相对于主对角线的偏移量。可以是正值或负值。默认为 0。
axis1 (int, optional) – 用作二维子数组的第一和第二轴,从中应提取对角线。默认值是 a 的前两个轴。
axis2 (int, optional) – 用作二维子数组的第一和第二轴,从中应提取对角线。默认值是 a 的前两个轴。
dtype (dtype, optional) – 确定返回数组的数据类型以及元素求和的累加器的数据类型。如果 dtype 的值为 None 且 a 是精度低于默认整数精度的整数类型,则使用默认整数精度。否则,精度与 a 相同。
out (None)
- 返回:
sum_along_diagonals – 如果 a 是二维的,则返回对角线上的和。如果 a 有更大的维度,则返回沿对角线的和的数组。
- 返回类型:
ndarray