jax.numpy.identity

目录

jax.numpy.identity#

jax.numpy.identity(n, dtype=None)[源代码][源代码]#

创建一个单位矩阵

JAX 实现的 numpy.identity()

参数:
  • n (DimSize) – 整数,指定每个数组维度的尺寸。

  • dtype (DTypeLike | None) – 可选的数据类型;默认为浮点型。

返回:

形状为 (n, n) 的单位矩阵。

返回类型:

Array

参见

jax.numpy.eye(): 非方阵和/或偏移单位矩阵。

示例

一个简单的 3x3 单位矩阵:

>>> jnp.identity(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

一个2x2的整数单位矩阵:

>>> jnp.identity(2, dtype=int)
Array([[1, 0],
       [0, 1]], dtype=int32)