export
methodModel.export(filepath, format="tf_saved_model")
创建一个用于推理的TF SavedModel工件.
注意: 目前这只能用于TensorFlow或JAX后端.
此方法允许您将模型导出到一个轻量级的SavedModel工件,该工件仅包含模型的前向传递(其call()
方法),并且可以通过例如TF-Serving进行服务.前向传递在名称serve()
下注册(见下面的示例).
模型的原始代码(包括您可能使用的任何自定义层)不再需要重新加载工件——它是完全独立的.
参数:
filepath: str
或 pathlib.Path
对象.保存工件的路径.
示例:
# 创建工件
model.export("path/to/location")
# 稍后,在不同的进程/环境中...
reloaded_artifact = tf.saved_model.load("path/to/location")
predictions = reloaded_artifact.serve(input_data)
如果您想自定义服务端点,可以使用较低级别的keras.export.ExportArchive
类.export()
方法在内部依赖于ExportArchive
.
ExportArchive
classkeras.export.ExportArchive()
ExportArchive 用于写入 SavedModel 制品(例如用于推理).
如果你有一个 Keras 模型或层,你希望将其导出为 SavedModel 以供服务(例如通过 TensorFlow-Serving),你可以使用 ExportArchive
来配置你需要提供的不同服务端点以及它们的签名.只需实例化一个 ExportArchive
,使用 track()
注册要使用的层或模型,然后使用 add_endpoint()
方法注册一个新的服务端点.完成后,使用 write_out()
方法保存制品.
生成的制品是一个 SavedModel,可以通过 tf.saved_model.load
重新加载.
示例:
以下是如何导出模型以进行推理.
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
# 在其他地方,我们可以重新加载制品并提供服务.
# 我们添加的端点可以作为方法使用:
serving_model = tf.saved_model.load("path/to/location")
outputs = serving_model.serve(inputs)
以下是如何导出一个模型,其中一个端点用于推理,另一个端点用于训练模式前向传递(例如启用 dropout).
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="call_inference",
fn=lambda x: model.call(x, training=False),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.add_endpoint(
name="call_training",
fn=lambda x: model.call(x, training=True),
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
export_archive.write_out("path/to/location")
关于资源跟踪的注意事项:
ExportArchive
能够自动跟踪其端点使用的所有 tf.Variables
,因此大多数情况下调用 .track(model)
并不是严格必需的.然而,如果你的模型使用了查找层,例如 IntegerLookup
、StringLookup
或 TextVectorization
,则需要通过 .track(model)
显式跟踪.
如果你需要能够访问复活档案上的属性 variables
、trainable_variables
或 non_trainable_variables
,也需要显式跟踪.
add_endpoint
methodExportArchive.add_endpoint(name, fn, input_signature=None, jax2tf_kwargs=None)
注册一个新的服务端点.
参数:
name: 字符串,端点的名称.
fn: 一个函数.它应该仅利用在ExportArchive
跟踪的模型/层上可用的资源
(例如tf.Variable
对象或tf.lookup.StaticHashTable
对象).
你可以调用.track(model)
来跟踪一个新的模型.
函数的输入形状和数据类型必须是已知的.为此,你可以选择以下两种方式之一:
1) 确保fn
是一个至少被调用过一次的tf.function
,或者
2) 提供一个input_signature
参数来指定输入的形状和数据类型(见下文).
input_signature: 用于指定fn
输入的形状和数据类型.
tf.TensorSpec
对象的列表(每个位置输入参数对应一个).
允许嵌套参数(见下面的示例,展示了一个具有两个输入参数的Functional模型).
jax2tf_kwargs: 可选.一个字典,用于传递给jax2tf
的参数.
仅当后端是JAX时支持.参见jax2tf.convert
的文档.
native_serialization
和polymorphic_shapes
的值,如果未提供,会自动计算.
返回:
添加到归档中的包装fn
的tf.function
.
示例:
当模型有一个输入参数时,使用input_signature
参数添加端点:
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
当模型有两个位置输入参数时,使用input_signature
参数添加端点:
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
)
当模型有一个输入参数是2个张量的列表时(例如一个具有两个输入的Functional模型),使用input_signature
参数添加端点:
model = keras.Model(inputs=[x1, x2], outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
[
tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
],
],
)
这也适用于字典输入:
model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[
{
"x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
"x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
},
],
)
添加一个tf.function
端点:
@tf.function()
def serving_fn(x):
return model(x)
# 函数必须被追踪,即它必须至少被调用一次.
serving_fn(tf.random.normal(shape=(2, 3)))
export_archive = ExportArchive()
export_archive.track(model)
export_archive.add_endpoint(name="serve", fn=serving_fn)
add_variable_collection
methodExportArchive.add_variable_collection(name, variables)
注册一组变量,以便在重新加载后检索.
参数:
name: 集合的字符串名称.
variables: 一个包含 tf.Variable
实例的元组/列表/集合.
示例:
export_archive = ExportArchive()
export_archive.track(model)
# 注册一个端点
export_archive.add_endpoint(
name="serve",
fn=model.call,
input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
)
# 保存一个变量集合
export_archive.add_variable_collection(
name="optimizer_variables", variables=model.optimizer.variables)
export_archive.write_out("path/to/location")
# 重新加载对象
revived_object = tf.saved_model.load("path/to/location")
# 检索变量
optimizer_variables = revived_object.optimizer_variables
track
methodExportArchive.track(resource)
跟踪层或模型的变量(和其他资产).
默认情况下,当您调用 add_endpoint()
时,端点函数使用的所有变量都会自动跟踪.然而,非变量资产(如查找表)需要手动跟踪.请注意,内置 Keras 层(TextVectorization
、IntegerLookup
、StringLookup
)使用的查找表在 add_endpoint()
中会自动跟踪.
参数: resource: 一个可跟踪的 TensorFlow 资源.
write_out
methodExportArchive.write_out(filepath, options=None)
将相应的SavedModel写入磁盘.
参数:
filepath: str
或 pathlib.Path
对象.
保存制品的路径.
options: tf.saved_model.SaveOptions
对象,指定
SavedModel保存选项.
关于TF-Serving的说明:通过add_endpoint()
注册的所有端点
在SavedModel制品中对TF-Serving可见.此外,
第一个注册的端点在别名"serving_default"
下可见(除非已经手动注册了名为
"serving_default"
的端点),
因为TF-Serving要求设置此端点.