开发者指南 / 自定义保存和序列化

自定义保存和序列化

作者: Neel Kovelamudi
创建日期: 2023/03/15
最后修改: 2023/03/15
描述: 关于为您的层和模型自定义保存的更高级的指南。

在 Colab 中查看 GitHub 源代码


介绍

本指南涵盖可以在 Keras 保存中自定义的高级方法。对于大多数用户,主要 序列化、保存和导出指南 中概述的方法已经足够。

API

我们将涵盖以下 API:

  • save_assets()load_assets()
  • save_own_variables()load_own_variables()
  • get_build_config()build_from_config()
  • get_compile_config()compile_from_config()

在恢复模型时,这些方法按以下顺序执行:

  • build_from_config()
  • compile_from_config()
  • load_own_variables()
  • load_assets()

设置

import os
import numpy as np
import keras

状态保存自定义

这些方法确定在调用 model.save() 时如何保存模型层的状态。您可以重写它们以完全控制状态保存过程。

save_own_variables()load_own_variables()

这些方法在调用 model.save()keras.models.load_model() 时分别保存和加载层的状态变量。默认情况下,保存和加载的状态变量是层的权重(包括可训练和不可训练的)。下面是 save_own_variables() 的默认实现:

def save_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        store[f"{i}"] = v.numpy()

这些方法使用的存储是一个字典,可以用层变量填充。让我们来看一个自定义的示例。

示例:

@keras.utils.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        super().save_own_variables(store)
        # 保存变量的值
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        # 加载时分配变量的值
        self.my_variable.assign(store["variables"])
        # 加载剩余的权重
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # 请注意:您必须指定如何在 `load_own_variables.` 中加载所有变量(包括层权重)。

    def call(self, inputs):
        dense_out = super().call(inputs)
        return dense_out + self.my_variable


model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
ref_output = np.random.random((8, 10))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 101ms/step - loss: 0.2908

save_assets()load_assets()

这些方法可以添加到您的模型类定义中,以存储和加载模型所需的任何额外信息。

例如,NLP 领域的层,如文本向量化层和索引查找层,可能需要在保存时将其相关的词汇表(或查找表)存储在文本文件中。

让我们看看这个工作流的基础知识,使用一个简单的文件 assets.txt

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        # 在保存时将词汇(句子)写入文本文件。
        with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        # 在加载时从文本文件中读取词汇(句子)。
        with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f:
            text = f.read()
        self.vocab = text.replace("<unk>", "little")


model = keras.Sequential(
    [LayerWithCustomAssets(vocab="Mary had a <unk> lamb.", units=5)]
)

x = np.random.random((10, 10))
y = model(x)

model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    restored_model.layers[0].vocab, "Mary had a little lamb."
)

buildcompile 的保存自定义

get_build_config()build_from_config()

这些方法一起工作以保存层的构建状态并在加载时恢复它们。

默认情况下,这仅包括一个包含层输入形状的构建配置字典,但可以覆盖这些方法以包括进一步的变量和查找表,这对于恢复构建的模型是有用的。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return dict(units=self.units, **super().get_config())

    def build(self, input_shape, layer_init):
        # 注意重写了 `build()` 以添加额外的参数。
        # 因此,我们需要在第一次执行 `call()` 之前手动调用带有 `layer_init` 参数的 build。
        super().build(input_shape)
        self._input_shape = input_shape
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        self.layer_init = layer_init

    def get_build_config(self):
        build_config = {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }  # 存储我们的初始化器用于 `build()`
        return build_config

    def build_from_config(self, config):
        # 在加载时使用参数调用 `build()`
        self.build(config["input_shape"], config["layer_init"])


custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [
        custom_layer,
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

x = np.random.random((16, 8))
y = model(x)

model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal")
np.testing.assert_equal(restored_model.built, True)

get_compile_config()compile_from_config()

这些方法一起工作以保存模型编译时的信息(优化器、损失等),并使用这些信息恢复和重新编译模型。

覆盖这些方法对于使用自定义优化器、自定义损失等来编译恢复的模型是有用的,因为在调用 compile_from_config() 中的 model.compile 之前需要对这些进行反序列化。

让我们看一个这方面的示例。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    loss = keras.ops.square(y_pred - y_true)
    loss = loss / 10.0
    loss = keras.ops.sum(loss, axis=1)
    return loss


@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    return keras.ops.mean(y_pred)


@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def compile(self, optimizer, loss_fn, metrics):
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        # 这些参数将在保存时序列化。
        return {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }

    def compile_from_config(self, config):
        # 反序列化编译参数(重要,因为许多是自定义的)
        optimizer = keras.utils.deserialize_keras_object(config["model_optimizer"])
        loss_fn = keras.utils.deserialize_keras_object(config["loss_fn"])
        metrics = keras.utils.deserialize_keras_object(config["metric"])

        # 使用反序列化的参数调用compile
        self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics)


model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred]
)

x = np.random.random((4, 8))
y = np.random.random((4,))

model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)
np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)
np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - 准确率: 0.0000e+00 - 损失: 0.0627 - 平均度量包装器: 0.2500

结论

使用本教程中学到的方法允许多种使用情况,允许保存和加载具有奇特资产和状态元素的复杂模型。总结如下:

  • save_own_variablesload_own_variables 决定你的状态如何被保存和加载。
  • save_assetsload_assets 可以添加以存储和加载模型所需的任何附加信息。
  • get_build_configbuild_from_config 保存和恢复模型的构建状态。
  • get_compile_configcompile_from_config 保存和恢复模型的编译状态。