save
methodModel.save(filepath, overwrite=True, zipped=None, **kwargs)
保存模型为.keras
文件.
参数:
filepath: str
或 pathlib.Path
对象.
保存模型的路径.必须以.keras
结尾(除非通过zipped=False
保存为未压缩的目录).
overwrite: 是否应该覆盖目标位置上已有的模型,或者通过交互式提示询问用户.
zipped: 是否将模型保存为压缩的.keras
存档(默认在本地保存时),或保存为未压缩的目录(默认在Hugging Face Hub上保存时).
示例:
model = keras.Sequential(
[
keras.layers.Dense(5, input_shape=(3,)),
keras.layers.Softmax(),
],
)
model.save("model.keras")
loaded_model = keras.saving.load_model("model.keras")
x = keras.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
注意 model.save()
是 keras.saving.save_model()
的别名.
保存的.keras
文件包含:
因此模型可以被重新实例化为完全相同的状态.
save_model
functionkeras.saving.save_model(model, filepath, overwrite=True, zipped=None, **kwargs)
将模型保存为.keras
文件.
参数:
model: 要保存的Keras模型实例.
filepath: str
或pathlib.Path
对象.保存模型的路径.
overwrite: 是否应该覆盖目标位置上任何现有的模型,或者通过交互式提示询问用户.
zipped: 是否将模型保存为压缩的.keras
存档(默认在本地保存时),或者保存为未压缩的目录(默认在保存到Hugging Face Hub时).
示例:
model = keras.Sequential(
[
keras.layers.Dense(5, input_shape=(3,)),
keras.layers.Softmax(),
],
)
model.save("model.keras")
loaded_model = keras.saving.load_model("model.keras")
x = keras.random.uniform((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
注意,model.save()
是keras.saving.save_model()
的别名.
保存的.keras
文件是一个包含以下内容的zip
存档:
因此,模型可以在完全相同的状态下重新实例化.
load_model
functionkeras.saving.load_model(filepath, custom_objects=None, compile=True, safe_mode=True)
加载通过 model.save()
保存的模型.
参数:
filepath: str
或 pathlib.Path
对象,保存模型文件的路径.
custom_objects: 可选的字典,将名称(字符串)映射到自定义类或函数,
在反序列化过程中需要考虑.
compile: 布尔值,是否在加载后编译模型.
safe_mode: 布尔值,是否禁止不安全的 lambda
反序列化.
当 safe_mode=False
时,加载对象有可能触发任意代码执行.
此参数仅适用于 Keras v3 模型格式.默认为 True
.
返回:
一个 Keras 模型实例.如果原始模型已编译,
并且参数 compile=True
被设置,则返回的模型将被编译.
否则,模型将保持未编译状态.
示例:
model = keras.Sequential([
keras.layers.Dense(5, input_shape=(3,)),
keras.layers.Softmax()])
model.save("model.keras")
loaded_model = keras.saving.load_model("model.keras")
x = np.random.random((10, 3))
assert np.allclose(model.predict(x), loaded_model.predict(x))
注意,模型变量在被重新加载后可能具有不同的名称值
(var.name
属性,例如 "dense_1/kernel:0"
).
建议你使用层属性来访问特定变量,例如 model.get_layer("dense_1").kernel
.