jax.numpy.tri

目录

jax.numpy.tri#

jax.numpy.tri(N, M=None, k=0, dtype=None)[源代码][源代码]#

返回一个数组,其中对角线及其下方为1,其他位置为0。

JAX 实现的 numpy.tri()

参数:
  • N (int) – int. 返回数组的行维度。

  • M (int | None) – 可选, int. 返回数组的列的维度。如果未指定,则 M = N

  • k (int) – 可选, int, 默认=0。指定数组中填充为1的次对角线。k=0 指主对角线, k<0 指主对角线下方的次对角线, k>0 指主对角线上方的次对角线。

  • dtype (DTypeLike | None) – 可选,返回数组的类型。默认类型为浮点型。

返回:

一个形状为 (N, M) 的数组,其中包含由 k 指定的次对角线下方的下三角元素被设置为1,其他地方为0。

返回类型:

Array

参见

示例

>>> jnp.tri(3)
Array([[1., 0., 0.],
       [1., 1., 0.],
       [1., 1., 1.]], dtype=float32)

M 不等于 N 时:

>>> jnp.tri(3, 4)
Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.]], dtype=float32)

k>0 时:

>>> jnp.tri(3, k=1)
Array([[1., 1., 0.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)

k<0 时:

>>> jnp.tri(3, 4, k=-1)
Array([[0., 0., 0., 0.],
       [1., 0., 0., 0.],
       [1., 1., 0., 0.]], dtype=float32)