jax.numpy.identity#
- jax.numpy.identity(n, dtype=None)[源代码][源代码]#
创建一个单位矩阵
JAX 实现的
numpy.identity()
。- 参数:
n (DimSize) – 整数,指定每个数组维度的尺寸。
dtype (DTypeLike | None) – 可选的数据类型;默认为浮点型。
- 返回:
形状为
(n, n)
的单位矩阵。- 返回类型:
参见
jax.numpy.eye()
: 非方阵和/或偏移单位矩阵。示例
一个简单的 3x3 单位矩阵:
>>> jnp.identity(3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
一个2x2的整数单位矩阵:
>>> jnp.identity(2, dtype=int) Array([[1, 0], [0, 1]], dtype=int32)