JaxLayer classkeras.layers.JaxLayer(
call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)
Keras层,封装了一个JAX模型.
该层在使用JAX作为Keras后端时,启用Keras中使用JAX组件的功能.
该层接受JAX模型作为函数call_fn,该函数必须接受以下具有这些确切名称的参数:
params:模型的可训练参数.state(可选):模型的不可训练状态.如果模型没有不可训练状态,可以省略.rng(可选):一个jax.random.PRNGKey实例.如果模型在训练或推理过程中不需要随机数生成器,可以省略.inputs:模型的输入,一个JAX数组或数组的PyTree.training(可选):指定我们是否处于训练模式或推理模式的参数,True表示训练模式.如果模型在训练模式和推理模式下的行为相同,可以省略.inputs参数是必须的.模型的输入必须通过单个参数提供.如果JAX模型将多个输入作为单独的参数,则必须将它们组合成单个结构,例如在tuple或dict中.
模型的params和state的初始化可以由该层处理,在这种情况下必须提供init_fn参数.这允许模型用正确的形状动态初始化.或者,如果形状已知,可以使用params参数和可选的state参数创建已初始化的模型.
如果提供了init_fn函数,则必须接受以下具有这些确切名称的参数:
rng:一个jax.random.PRNGKey实例.inputs:一个JAX数组或具有占位符值的数组的PyTree,以提供输入的形状.training(可选):指定我们是否处于训练模式或推理模式的参数.True总是传递给init_fn.无论call_fn是否有training参数,都可以省略.对于具有不可训练状态的JAX模型:
call_fn必须有一个state参数call_fn必须返回一个包含模型输出和新不可训练状态的tupleinit_fn必须返回一个包含模型初始可训练参数和初始不可训练状态的tuple.这段代码展示了一个可能的call_fn和init_fn签名的组合,用于具有不可训练状态的模型.在这个例子中,模型在call_fn中有一个training参数和一个rng参数.
def stateful_call(params, state, rng, inputs, training):
outputs = ...
new_state = ...
return outputs, new_state
def stateful_init(rng, inputs):
initial_params = ...
initial_state = ...
return initial_params, initial_state
对于没有不可训练状态的JAX模型:
call_fn不得有state参数call_fn必须仅返回模型的输出init_fn必须仅返回模型的初始可训练参数.这段代码展示了一个可能的call_fn和init_fn签名的组合,用于没有不可训练状态的模型.在这个例子中,模型在call_fn中没有training参数,也没有rng参数.
def stateless_call(params, inputs):
outputs = ...
return outputs
def stateless_init(rng, inputs):
initial_params = ...
return initial_params
如果模型的签名与JaxLayer所需的签名不同,可以轻松编写一个包装方法来适应参数.这个例子展示了一个模型,该模型将多个输入作为单独的参数,期望在dict中有多