JaxLayer
classkeras.layers.JaxLayer(
call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)
Keras层,封装了一个JAX模型.
该层在使用JAX作为Keras后端时,启用Keras中使用JAX组件的功能.
该层接受JAX模型作为函数call_fn
,该函数必须接受以下具有这些确切名称的参数:
params
:模型的可训练参数.state
(可选):模型的不可训练状态.如果模型没有不可训练状态,可以省略.rng
(可选):一个jax.random.PRNGKey
实例.如果模型在训练或推理过程中不需要随机数生成器,可以省略.inputs
:模型的输入,一个JAX数组或数组的PyTree
.training
(可选):指定我们是否处于训练模式或推理模式的参数,True
表示训练模式.如果模型在训练模式和推理模式下的行为相同,可以省略.inputs
参数是必须的.模型的输入必须通过单个参数提供.如果JAX模型将多个输入作为单独的参数,则必须将它们组合成单个结构,例如在tuple
或dict
中.
模型的params
和state
的初始化可以由该层处理,在这种情况下必须提供init_fn
参数.这允许模型用正确的形状动态初始化.或者,如果形状已知,可以使用params
参数和可选的state
参数创建已初始化的模型.
如果提供了init_fn
函数,则必须接受以下具有这些确切名称的参数:
rng
:一个jax.random.PRNGKey
实例.inputs
:一个JAX数组或具有占位符值的数组的PyTree
,以提供输入的形状.training
(可选):指定我们是否处于训练模式或推理模式的参数.True
总是传递给init_fn
.无论call_fn
是否有training
参数,都可以省略.对于具有不可训练状态的JAX模型:
call_fn
必须有一个state
参数call_fn
必须返回一个包含模型输出和新不可训练状态的tuple
init_fn
必须返回一个包含模型初始可训练参数和初始不可训练状态的tuple
.这段代码展示了一个可能的call_fn
和init_fn
签名的组合,用于具有不可训练状态的模型.在这个例子中,模型在call_fn
中有一个training
参数和一个rng
参数.
def stateful_call(params, state, rng, inputs, training):
outputs = ...
new_state = ...
return outputs, new_state
def stateful_init(rng, inputs):
initial_params = ...
initial_state = ...
return initial_params, initial_state
对于没有不可训练状态的JAX模型:
call_fn
不得有state
参数call_fn
必须仅返回模型的输出init_fn
必须仅返回模型的初始可训练参数.这段代码展示了一个可能的call_fn
和init_fn
签名的组合,用于没有不可训练状态的模型.在这个例子中,模型在call_fn
中没有training
参数,也没有rng
参数.
def stateless_call(params, inputs):
outputs = ...
return outputs
def stateless_init(rng, inputs):
initial_params = ...
return initial_params
如果模型的签名与JaxLayer
所需的签名不同,可以轻松编写一个包装方法来适应参数.这个例子展示了一个模型,该模型将多个输入作为单独的参数,期望在dict
中有多