Keras 3 API 文档 / 模型 API / 保存和序列化 / 模型导出用于推理

模型导出用于推理

[source]

export method

Model.export(filepath, format="tf_saved_model")

创建一个用于推理的TF SavedModel工件.

注意: 目前这只能用于TensorFlow或JAX后端.

此方法允许您将模型导出到一个轻量级的SavedModel工件,该工件仅包含模型的前向传递(其call()方法),并且可以通过例如TF-Serving进行服务.前向传递在名称serve()下注册(见下面的示例).

模型的原始代码(包括您可能使用的任何自定义层)不再需要重新加载工件——它是完全独立的.

参数: filepath: strpathlib.Path 对象.保存工件的路径.

示例:

# 创建工件
model.export("path/to/location")

# 稍后,在不同的进程/环境中...
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)

如果您想自定义服务端点,可以使用较低级别的keras.export.ExportArchive类.export()方法在内部依赖于ExportArchive.


[source]

ExportArchive class

keras.export.ExportArchive()

ExportArchive 用于写入 SavedModel 制品(例如用于推理).

如果你有一个 Keras 模型或层,你希望将其导出为 SavedModel 以供服务(例如通过 TensorFlow-Serving),你可以使用 ExportArchive 来配置你需要提供的不同服务端点以及它们的签名.只需实例化一个 ExportArchive,使用 track() 注册要使用的层或模型,然后使用 add_endpoint() 方法注册一个新的服务端点.完成后,使用 write_out() 方法保存制品.

生成的制品是一个 SavedModel,可以通过 tf.saved_model.load 重新加载.

示例:

以下是如何导出模型以进行推理.

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

# 在其他地方,我们可以重新加载制品并提供服务.
# 我们添加的端点可以作为方法使用:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)

以下是如何导出一个模型,其中一个端点用于推理,另一个端点用于训练模式前向传递(例如启用 dropout).

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="call_inference",
    fn=lambda x: model.call(x, training=False),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
    name="call_training",
    fn=lambda x: model.call(x, training=True),
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")

关于资源跟踪的注意事项:

ExportArchive 能够自动跟踪其端点使用的所有 tf.Variables,因此大多数情况下调用 .track(model) 并不是严格必需的.然而,如果你的模型使用了查找层,例如 IntegerLookupStringLookupTextVectorization,则需要通过 .track(model) 显式跟踪.

如果你需要能够访问复活档案上的属性 variablestrainable_variablesnon_trainable_variables,也需要显式跟踪.


[source]

add_endpoint method

ExportArchive.add_endpoint(name, fn, input_signature=None, jax2tf_kwargs=None)

注册一个新的服务端点.

参数: name: 字符串,端点的名称. fn: 一个函数.它应该仅利用在ExportArchive跟踪的模型/层上可用的资源 (例如tf.Variable对象或tf.lookup.StaticHashTable对象). 你可以调用.track(model)来跟踪一个新的模型. 函数的输入形状和数据类型必须是已知的.为此,你可以选择以下两种方式之一: 1) 确保fn是一个至少被调用过一次的tf.function,或者 2) 提供一个input_signature参数来指定输入的形状和数据类型(见下文). input_signature: 用于指定fn输入的形状和数据类型. tf.TensorSpec对象的列表(每个位置输入参数对应一个). 允许嵌套参数(见下面的示例,展示了一个具有两个输入参数的Functional模型). jax2tf_kwargs: 可选.一个字典,用于传递给jax2tf的参数. 仅当后端是JAX时支持.参见jax2tf.convert的文档. native_serializationpolymorphic_shapes的值,如果未提供,会自动计算.

返回: 添加到归档中的包装fntf.function.

示例:

当模型有一个输入参数时,使用input_signature参数添加端点:

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)

当模型有两个位置输入参数时,使用input_signature参数添加端点:

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
    ],
)

当模型有一个输入参数是2个张量的列表时(例如一个具有两个输入的Functional模型),使用input_signature参数添加端点:

model = keras.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        [
            tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        ],
    ],
)

这也适用于字典输入:

model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[
        {
            "x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
            "x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        },
    ],
)

添加一个tf.function端点:

@tf.function()
def serving_fn(x):
    return model(x)

# 函数必须被追踪,即它必须至少被调用一次.
serving_fn(tf.random.normal(shape=(2, 3)))

export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)

[source]

add_variable_collection method

ExportArchive.add_variable_collection(name, variables)

注册一组变量,以便在重新加载后检索.

参数: name: 集合的字符串名称. variables: 一个包含 tf.Variable 实例的元组/列表/集合.

示例:

export_archive = ExportArchive()
export_archive.track(model)
# 注册一个端点
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
# 保存一个变量集合
export_archive.add_variable_collection(
    name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")

# 重新加载对象
revived_object = tf.saved_model.load("path/to/location")
# 检索变量
optimizer_variables = revived_object.optimizer_variables

[source]

track method

ExportArchive.track(resource)

跟踪层或模型的变量(和其他资产).

默认情况下,当您调用 add_endpoint() 时,端点函数使用的所有变量都会自动跟踪.然而,非变量资产(如查找表)需要手动跟踪.请注意,内置 Keras 层(TextVectorizationIntegerLookupStringLookup)使用的查找表在 add_endpoint() 中会自动跟踪.

参数: resource: 一个可跟踪的 TensorFlow 资源.


[source]

write_out method

ExportArchive.write_out(filepath, options=None)

将相应的SavedModel写入磁盘.

参数: filepath: strpathlib.Path 对象. 保存制品的路径. options: tf.saved_model.SaveOptions 对象,指定 SavedModel保存选项.

关于TF-Serving的说明:通过add_endpoint()注册的所有端点 在SavedModel制品中对TF-Serving可见.此外, 第一个注册的端点在别名"serving_default"下可见(除非已经手动注册了名为 "serving_default"的端点), 因为TF-Serving要求设置此端点.