jax.example_libraries.stax
模块#
Stax 是一个从小到大构建的灵活的小型神经网络规范库。
你可能并不打算导入这个模块!Stax 仅作为示例库使用。JAX 有许多其他功能更全面的神经网络库,包括来自 Google 的 Flax 和来自 DeepMind 的 Haiku。
- jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)[源代码]#
池化层的层构造函数。
- jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)[源代码][源代码]#
批量归一化层的层构造函数。
- jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
通用卷积层的层构造函数。
- jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
通用转置卷积层的层构造函数。
- jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)#
通用转置卷积层的层构造函数。
- jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)[源代码][源代码]#
用于密集(全连接)层的层构造函数。
- jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[源代码][源代码]#
通用卷积层的层构造函数。
- jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)[源代码][源代码]#
通用转置卷积层的层构造函数。
- jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)[源代码]#
池化层的层构造函数。
- jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)[源代码]#
池化层的层构造函数。
- jax.example_libraries.stax.parallel(*layers)[源代码][源代码]#
用于并行组合层的组合器。
这个组合器产生的层通常与 FanOut 和 FanInSum 层一起使用。
- 参数:
*layers – 一系列层,每层都是一个 (init_fun, apply_fun) 对。
- 返回:
一个新的层,表示为一个 (init_fun, apply_fun) 对,代表了给定层序列的并行组合。特别是,返回的层接受一个输入序列,并返回一个与参数 layers 长度相同的输出序列。