jax.numpy.linalg.norm

目录

jax.numpy.linalg.norm#

jax.numpy.linalg.norm(x, ord=None, axis=None, keepdims=False)[源代码][源代码]#

计算矩阵或向量的范数。

JAX 实现的 numpy.linalg.norm()

参数:
  • x (ArrayLike) – 将计算其范数的 N 维数组。

  • ord (int | str | None) – 指定要采用的范数类型。默认情况下,矩阵为Frobenius范数,向量为2-范数。其他选项请参见下面的注释。

  • axis (None | tuple[int, ...] | int) – 指定计算范数的轴的整数或整数序列。默认为 x 的所有轴。

  • keepdims (bool) – 如果为 True,输出数组将具有与输入相同的维度,减少的轴的大小将被替换为 ``1``(默认:False)。

返回:

包含指定范数的 x 的数组。

返回类型:

Array

备注

计算的范数类型取决于 ord 的值和被减少的轴的数量。

对于 向量范数 (即单轴缩减):

  • ord=None (默认) 计算2-范数

  • ord=inf 计算 max(abs(x))

  • ord=-inf 计算 min(abs(x))``

  • ord=0 计算 sum(x!=0)

  • 对于其他数值,计算 sum(abs(x) ** ord)**(1/ord)

对于 矩阵范数 (即两个轴的缩减):

  • ord='fro'ord=None (默认) 计算 Frobenius 范数

  • ord='nuc' 计算核范数,即奇异值之和

  • ord=1 计算 max(abs(x).sum(0))

  • ord=-1 计算 min(abs(x).sum(0))

  • ord=2 计算2-范数,即最大奇异值

  • ord=-2 计算最小的奇异值

示例

向量范数:

>>> x = jnp.array([3., 4., 12.])
>>> jnp.linalg.norm(x)
Array(13., dtype=float32)
>>> jnp.linalg.norm(x, ord=1)
Array(19., dtype=float32)
>>> jnp.linalg.norm(x, ord=0)
Array(3., dtype=float32)

矩阵范数:

>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 7.]])
>>> jnp.linalg.norm(x)  # Frobenius norm
Array(10.198039, dtype=float32)
>>> jnp.linalg.norm(x, ord='nuc')  # nuclear norm
Array(10.762535, dtype=float32)
>>> jnp.linalg.norm(x, ord=1)  # 1-norm
Array(10., dtype=float32)

批量向量范数:

>>> jnp.linalg.norm(x, axis=1)
Array([3.7416575, 9.486833 ], dtype=float32)