Flax层

[source]

FlaxLayer class

keras.layers.FlaxLayer(module, method=None, variables=None, **kwargs)

Keras层,封装了一个Flax模块.

该层使得在使用JAX作为Keras后端时,能够以flax.linen.Module 实例的形式使用Flax组件.

可以通过method参数指定用于前向传播的模块方法,默认为__call__.该方法必须带有以下确切名称的参数:

  • self(如果方法是绑定到模块的,这是__call__的默认情况),否则为module以传递模块.
  • inputs:模型的输入,一个JAX数组或数组的PyTree.
  • training (可选):指定我们是否处于训练模式或推理模式的参数,在训练模式下传递True.

FlaxLayer自动处理模型的不可训练状态和所需的RNG.注意,flax.linen.Module.apply()mutable参数设置为DenyList(["params"]),因此假设"params"集合外的所有变量都是不可训练的权重.

以下示例展示了如何从Flax Module创建一个使用默认__call__方法且无训练参数的FlaxLayer:

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, inputs):
        x = inputs
        x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
        x = flax.linen.relu(x)
        x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = flax.linen.Dense(features=200)(x)
        x = flax.linen.relu(x)
        x = flax.linen.Dense(features=10)(x)
        x = flax.linen.softmax(x)
        return x

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(flax_module)

以下示例展示了如何包装模块方法以符合所需签名.这允许有多个输入参数和一个不同名称和值的训练参数.这还展示了如何使用未绑定到模块的函数.

class MyFlaxModule(flax.linen.Module):
    @flax.linen.compact
    def forward(self, input1, input2, deterministic):
        ...
        return outputs

def my_flax_module_wrapper(module, inputs, training):
    input1, input2 = inputs
    return module.forward(input1, input2, not training)

flax_module = MyFlaxModule()
keras_layer = FlaxLayer(
    module=flax_module,
    method=my_flax_module_wrapper,
)

参数: module: flax.linen.Module或其子类的实例. method: 调用模型的方法.这通常是Module中的一个方法.如果未提供,则使用__call__方法.method也可以是未在Module中定义的函数,在这种情况下,它必须将Module作为第一个参数.它用于Module.initModule.apply.详细信息记录在flax.linen.Module.apply()method参数中. variables: 包含模块所有变量的dict,格式与flax.linen.Module.init()返回的格式相同. 应包含一个"params"键和,如果适用,其他键用于非训练状态的变量集合.这允许传递训练参数和学习的非训练状态或控制初始化.如果传递None,则在构建时调用模块的init函数以初始化模型的变量.