jax.nn.initializers.delta_orthogonal#
- jax.nn.initializers.delta_orthogonal(scale=1.0, column_axis=-1, dtype=<class 'jax.numpy.float64'>)[源代码]#
构建一个delta正交核的初始化器。
- 参数:
scale (RealNumeric) – 均匀分布的上界。
column_axis (int) – 包含应正交的列的轴。
dtype (DTypeLikeInexact) – 权重的默认数据类型。
- 返回:
一个 delta 正交初始化器。传递给初始化器的形状必须是 3D、4D 或 5D。
- 返回类型:
Initializer
示例:
>>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.delta_orthogonal() >>> initializer(jax.random.key(42), (3, 3, 3), jnp.float32) Array([[[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]], [[ 0.27858758, -0.7949833 , -0.53887904], [ 0.9120717 , 0.04322892, 0.40774566], [-0.30085585, -0.6050892 , 0.73712474]], [[ 0. , 0. , 0. ], [ 0. , 0. , 0. ], [ 0. , 0. , 0. ]]], dtype=float32)