Jax层

[source]

JaxLayer class

keras.layers.JaxLayer(
    call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)

Keras层,封装了一个JAX模型.

该层在使用JAX作为Keras后端时,启用Keras中使用JAX组件的功能.

模型函数

该层接受JAX模型作为函数call_fn,该函数必须接受以下具有这些确切名称的参数:

  • params:模型的可训练参数.
  • state可选):模型的不可训练状态.如果模型没有不可训练状态,可以省略.
  • rng可选):一个jax.random.PRNGKey实例.如果模型在训练或推理过程中不需要随机数生成器,可以省略.
  • inputs:模型的输入,一个JAX数组或数组的PyTree.
  • training可选):指定我们是否处于训练模式或推理模式的参数,True表示训练模式.如果模型在训练模式和推理模式下的行为相同,可以省略.

inputs参数是必须的.模型的输入必须通过单个参数提供.如果JAX模型将多个输入作为单独的参数,则必须将它们组合成单个结构,例如在tupledict中.

模型权重初始化

模型的paramsstate的初始化可以由该层处理,在这种情况下必须提供init_fn参数.这允许模型用正确的形状动态初始化.或者,如果形状已知,可以使用params参数和可选的state参数创建已初始化的模型.

如果提供了init_fn函数,则必须接受以下具有这些确切名称的参数:

  • rng:一个jax.random.PRNGKey实例.
  • inputs:一个JAX数组或具有占位符值的数组的PyTree,以提供输入的形状.
  • training可选):指定我们是否处于训练模式或推理模式的参数.True总是传递给init_fn.无论call_fn是否有training参数,都可以省略.

具有不可训练状态的模型

对于具有不可训练状态的JAX模型:

  • call_fn必须有一个state参数
  • call_fn必须返回一个包含模型输出和新不可训练状态的tuple
  • init_fn必须返回一个包含模型初始可训练参数和初始不可训练状态的tuple.

这段代码展示了一个可能的call_fninit_fn签名的组合,用于具有不可训练状态的模型.在这个例子中,模型在call_fn中有一个training参数和一个rng参数.

def stateful_call(params, state, rng, inputs, training):
    outputs = ...
    new_state = ...
    return outputs, new_state

def stateful_init(rng, inputs):
    initial_params = ...
    initial_state = ...
    return initial_params, initial_state

没有不可训练状态的模型

对于没有不可训练状态的JAX模型:

  • call_fn不得有state参数
  • call_fn必须仅返回模型的输出
  • init_fn必须仅返回模型的初始可训练参数.

这段代码展示了一个可能的call_fninit_fn签名的组合,用于没有不可训练状态的模型.在这个例子中,模型在call_fn中没有training参数,也没有rng参数.

def stateless_call(params, inputs):
    outputs = ...
    return outputs

def stateless_init(rng, inputs):
    initial_params = ...
    return initial_params

符合所需签名

如果模型的签名与JaxLayer所需的签名不同,可以轻松编写一个包装方法来适应参数.这个例子展示了一个模型,该模型将多个输入作为单独的参数,期望在dict中有多