jax.nn.initializers.variance_scaling#
- jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, batch_axis=(), dtype=<class 'jax.numpy.float64'>)[源代码]#
初始化器,根据权重张量的形状调整其尺度。
使用
distribution="truncated_normal"
或distribution="normal"
,样本从均值为零且标准差(如有截断则为截断后)为 \(\sqrt{\frac{scale}{n}}\) 的(截断)正态分布中抽取,其中 n 是:权重张量中输入单元的数量,如果
mode="fan_in"
,输出单元的数量,如果
mode="fan_out"
,或者输入和输出单元数量的平均值,如果
mode="fan_avg"
。
此初始化器可以通过
in_axis
、out_axis
和batch_axis
进行配置,以适用于一般的卷积层或密集层;不在这些参数中的轴被假定为“感受野”(卷积核的空间轴)。使用
distribution="truncated_normal"
时,样本的绝对值在缩放前会被截断在2个标准差处。使用
distribution="uniform"
时,样本从以下分布中抽取:一个均匀的间隔,如果 dtype 是实数,或者
一个均匀的圆盘,如果 dtype 是复数,
其均值为零,标准差为 \(\sqrt{\frac{scale}{n}}\),其中 n 如上所述定义。
- 参数:
scale (RealNumeric) – 缩放因子(正浮点数)。
mode (Literal['fan_in'] | Literal['fan_out'] | Literal['fan_avg']) –
"fan_in"
,"fan_out"
, 和"fan_avg"
之一。distribution (Literal['truncated_normal'] | Literal['normal'] | Literal['uniform']) – 要使用的随机分布。可以是
"truncated_normal"
、"normal"
或"uniform"
之一。batch_axis (Sequence[int]) – 权重数组中应忽略的轴或轴序列。
dtype (DTypeLikeInexact) – 权重的数据类型。
- 返回类型:
Initializer