jax.nn.initializers.delta_orthogonal

jax.nn.initializers.delta_orthogonal#

jax.nn.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)[源代码]#

构建一个delta正交核的初始化器。

参数:
  • scale (RealNumeric) – 均匀分布的上界。

  • column_axis (int) – 包含应正交的列的轴。

  • dtype (DTypeLikeInexact) – 权重的默认数据类型。

返回:

一个 delta 正交初始化器。传递给初始化器的形状必须是 3D、4D 或 5D。

返回类型:

Initializer

示例:

>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.delta_orthogonal()
>>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32)  
Array([[[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]],

       [[ 0.27858758, -0.7949833 , -0.53887904],
        [ 0.9120717 ,  0.04322892,  0.40774566],
        [-0.30085585, -0.6050892 ,  0.73712474]],

       [[ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ]]], dtype=float32)