序列化工具

[source]

serialize_keras_object function

keras.saving.serialize_keras_object(obj)

获取通过序列化Keras对象得到的配置字典.

serialize_keras_object() 将一个Keras对象序列化为一个表示该对象的Python字典,并且是deserialize_keras_object()的互逆函数.有关配置格式的更多信息,请参见deserialize_keras_object().

参数: obj: 要序列化的Keras对象.

返回: 一个表示该对象的Python字典.该Python字典可以通过deserialize_keras_object()进行反序列化.


[source]

deserialize_keras_object function

keras.saving.deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
)

通过反序列化配置字典来检索对象.

配置字典是一个由一组键值对组成的Python字典,表示一个Keras对象,例如OptimizerLayerMetrics等.保存和加载库使用以下键来记录Keras对象的信息:

  • class_name:字符串.这是类的名称,与源代码中定义的完全一致,例如"LossesContainer".
  • config:字典.库定义或用户定义的键值对,存储对象的配置,通过object.get_config()获取.
  • module:字符串.Python模块的路径.内置的Keras类期望有前缀keras.
  • registered_name:字符串.通过keras.saving.register_keras_serializable(package, name) API注册的类的键.键的格式为'{package}>{name}',其中packagename是传递给register_keras_serializable()的参数.如果未提供name,则使用类名.如果registered_name成功解析为一个已注册的类,字典中的class_nameconfig值将不会被使用.registered_name仅用于非内置类.

例如,以下字典表示内置的Adam优化器及其相关配置:

dict_structure = {
    "class_name": "Adam",
    "config": {
        "amsgrad": false,
        "beta_1": 0.8999999761581421,
        "beta_2": 0.9990000128746033,
        "decay": 0.0,
        "epsilon": 1e-07,
        "learning_rate": 0.0010000000474974513,
        "name": "Adam"
    },
    "module": "keras.optimizers",
    "registered_name": None
}
# 返回一个与原始对象相同的`Adam`实例.
deserialize_keras_object(dict_structure)

如果类没有导出的Keras命名空间,库通过其moduleclass_name跟踪它.例如:

dict_structure = {
  "class_name": "MetricsList",
  "config": {
      ...
  },
  "module": "keras.trainers.compile_utils",
  "registered_name": "MetricsList"
}

# 返回一个与原始对象相同的`MetricsList`实例.
deserialize_keras_object(dict_structure)

以下字典表示用户自定义的MeanSquaredError损失:

@keras.saving.register_keras_serializable(package='my_package')
class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
  ...

dict_structure = {
    "class_name": "ModifiedMeanSquaredError",
    "config": {
        "fn": "mean_squared_error",
        "name": "mean_squared_error",
        "reduction": "auto"
    },
    "registered_name": "my_package>ModifiedMeanSquaredError"
}
# 返回`ModifiedMeanSquaredError`对象
deserialize_keras_object(dict_structure)

参数: config:描述对象的Python字典. custom_objects:包含自定义对象名称与相应类或函数映射的Python字典. safe_mode:布尔值,是否禁止不安全的lambda反序列化.当safe_mode=False时,加载对象有可能触发任意代码执行.此参数仅适用于Keras v3模型格式.默认为True.

返回: 由config字典描述的对象.


[source]

CustomObjectScope class

keras.saving.custom_object_scope(custom_objects)

向Keras反序列化内部暴露自定义类/函数.

with custom_object_scope(objects_dict)作用域下,Keras方法如keras.models.load_model()keras.models.model_from_config()将能够反序列化任何在保存的配置中引用的自定义对象(例如自定义层或度量标准).

示例:

考虑一个自定义正则化器my_regularizer:

layer = Dense(3, kernel_regularizer=my_regularizer)
# 配置包含对`my_regularizer`的引用
config = layer.get_config()
...
# 之后:
with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)

参数: custom_objects: {str: object}对的字典,其中str键是对象名称.


[source]

get_custom_objects function

keras.saving.get_custom_objects()

检索对自定义对象全局字典的实时引用.

使用 custom_object_scope() 设置的自定义对象不会添加到自定义对象的全局字典中,也不会出现在返回的字典中.

示例:

get_custom_objects().clear()
get_custom_objects()['MyObject'] = MyObject

返回: 全局字典,映射已注册的类名到类.


[source]

register_keras_serializable function

keras.saving.register_keras_serializable(package="Custom", name=None)

注册一个对象到Keras序列化框架.

这个装饰器将装饰的类或函数注入到Keras自定义对象字典中,以便它可以被序列化和反序列化,而无需在用户提供的自定义对象字典中有一个条目.它还注入了一个函数,Keras将调用该函数来获取对象的可序列化字符串键.

注意,为了被序列化和反序列化,类必须实现get_config()方法.函数没有这个要求.

对象将以键'package>name'注册,其中name默认为对象名称,如果未传递.

示例:

# 注意,这里使用`'my_package'`作为`package`参数,并且由于没有提供`name`参数,`'MyDense'`被用作`name`.
@register_keras_serializable('my_package')
class MyDense(keras.layers.Dense):
    pass

assert get_registered_object('my_package>MyDense') == MyDense
assert get_registered_name(MyDense) == 'my_package>MyDense'

参数: package: 这个类所属的包.这用于标识类的key(即"package>name").注意,这是传递给装饰器的第一个参数. name: 在这个包中序列化这个类的名称.如果未提供或为None,将使用类的名称(注意,当装饰器只使用一个参数时,这个参数成为package).

返回: 一个装饰器,将装饰的类用传递的名称注册.