jax.numpy.eye

目录

jax.numpy.eye#

jax.numpy.eye(N, M=None, k=0, dtype=None, *, device=None)[源代码][源代码]#

创建一个方形或矩形的单位矩阵

JAX 实现的 numpy.eye()

参数:
  • N (DimSize) – 指定数组的第一维的整数。

  • M (DimSize | None) – 可选的整数,指定数组的第二维度;默认为与 N 相同的值。

  • k (int | ArrayLike) – 可选的整数,指定对角线的偏移量。使用正值表示上对角线,负值表示下对角线。默认值为零。

  • dtype (DTypeLike | None) – 可选的数据类型;默认为浮点型。

  • device (xc.Device | Sharding | None) – 可选的 DeviceSharding,创建的数组将被提交到该设备或分片。

返回:

形状为 (N, M) 的标识数组,如果未指定 M,则为 (N, N)

返回类型:

Array

参见

jax.numpy.identity(): 生成方形单位矩阵的简化API。

示例

一个简单的 3x3 单位矩阵:

>>> jnp.eye(3)
Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

具有偏移对角线的整数单位矩阵:

>>> jnp.eye(3, k=1, dtype=int)
Array([[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]], dtype=int32)
>>> jnp.eye(3, k=-1, dtype=int)
Array([[0, 0, 0],
       [1, 0, 0],
       [0, 1, 0]], dtype=int32)

非方阵的单位矩阵:

>>> jnp.eye(3, 5, k=1)
Array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.]], dtype=float32)