Keras 3 API 文档 / 层 API / 特定于后端的层 / Tensorflow SavedModel 层

Tensorflow SavedModel 层

[source]

TFSMLayer class

keras.layers.TFSMLayer(
    filepath,
    call_endpoint="serve",
    call_training_endpoint=None,
    trainable=True,
    name=None,
    dtype=None,
)

重新加载通过SavedModel / ExportArchive保存的Keras模型/层.

参数: filepath: strpathlib.Path 对象.SavedModel的路径. call_endpoint: 用作重新加载层call()方法的端点名称.如果SavedModel是通过model.export()创建的, 则默认端点名称为'serve'.在其他情况下,可能命名为'serving_default'.

示例:

model.export("path/to/artifact")
reloaded_layer = TFSMLayer("path/to/artifact")
outputs = reloaded_layer(inputs)

重新加载的对象可以像常规Keras层一样使用,并支持其可训练权重的训练/微调.请注意,重新加载的对象不保留原始对象的内部结构或自定义方法——它是围绕保存函数创建的一个全新的层.

限制:

  • 仅支持具有单个inputs张量参数(可以是张量的字典/元组/列表)的调用端点.对于具有多个独立输入张量参数的端点,考虑子类化TFSMLayer并实现具有自定义签名的call()方法.
  • 如果您需要训练时行为与推理时行为不同(即,如果您需要重新加载的对象在__call__()中支持training=True参数),请确保训练时调用函数作为独立端点保存在工件中,并通过call_training_endpoint参数将其名称提供给TFSMLayer.