jax.numpy.tri#
- jax.numpy.tri(N, M=None, k=0, dtype=None)[源代码][源代码]#
返回一个数组,其中对角线及其下方为1,其他位置为0。
JAX 实现的
numpy.tri()
- 参数:
- 返回:
一个形状为
(N, M)
的数组,其中包含由k
指定的次对角线下方的下三角元素被设置为1,其他地方为0。- 返回类型:
参见
jax.numpy.tril()
: 返回数组的下三角部分。jax.numpy.triu()
: 返回数组的上三角部分。
示例
>>> jnp.tri(3) Array([[1., 0., 0.], [1., 1., 0.], [1., 1., 1.]], dtype=float32)
当
M
不等于N
时:>>> jnp.tri(3, 4) Array([[1., 0., 0., 0.], [1., 1., 0., 0.], [1., 1., 1., 0.]], dtype=float32)
当
k>0
时:>>> jnp.tri(3, k=1) Array([[1., 1., 0.], [1., 1., 1.], [1., 1., 1.]], dtype=float32)
当
k<0
时:>>> jnp.tri(3, 4, k=-1) Array([[0., 0., 0., 0.], [1., 0., 0., 0.], [1., 1., 0., 0.]], dtype=float32)