Keras 3 API 文档 / 回调API / 模型检查点

模型检查点

[source]

ModelCheckpoint class

keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    initial_value_threshold=None,
)

保存Keras模型或模型权重以某种频率的回调.

ModelCheckpoint 回调与使用 model.fit() 进行训练结合使用,以某种间隔保存模型或权重(在检查点文件中),以便稍后加载模型或权重以从保存的状态继续训练.

此回调提供的一些选项包括:

  • 是否仅保留迄今为止达到"最佳性能”的模型,或者是否在每个 epoch 结束时无论性能如何都保存模型.
  • "最佳”的定义;监控哪个量以及是应该最大化还是最小化.
  • 应保存的频率.目前,回调支持在每个 epoch 结束时保存,或在固定数量的训练批次后保存.
  • 仅保存权重,还是保存整个模型.

示例:

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])

EPOCHS = 10
checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# 如果模型是目前为止最好的,则在每个 epoch 结束时保存模型.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# 可以加载被认为是最好的模型 -
keras.models.load_model(checkpoint_filepath)

# 或者,可以仅检查点模型权重 -
checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

# 如果模型是目前为止最好的,则在每个 epoch 结束时保存模型权重.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# 可以加载被认为是最好的模型权重 -
model.load_weights(checkpoint_filepath)

参数: filepath: 字符串或 PathLike,保存模型文件的路径. filepath 可以包含命名格式化选项, 这些选项将用 epoch 的值和 logs 中的键填充 (在 on_epoch_end 中传递). filepath 名称需要在 save_weights_only=True 时以 ".weights.h5" 结尾, 或者在检查点保存整个模型时以 ".keras" 结尾(默认). 例如: 如果 filepath"{epoch:02d}-{val_loss:.2f}.keras",则模型检查点将使用 epoch 编号和验证损失保存在文件名中.filepath 的目录不应被任何其他回调重用以避免冲突. monitor: 要监控的指标名称.通常由 Model.compile 方法设置指标.注意: * 在名称前加上 "val_" 以监控验证指标. * 使用 "loss""val_loss" 监控模型的总损失. * 如果将指标指定为字符串,如 "accuracy",则传递相同的字符串(带或不带 "val_" 前缀). * 如果传递 metrics.Metric 对象,monitor 应设置为 metric.name * 如果不确定指标名称,可以检查 history.history 字典的内容,该字典由 history = model.fit() 返回 * 多输出模型在指标名称上设置额外的前缀. verbose: 详细模式,0 或 1.模式 0 静默,模式 1 在回调采取行动时显示消息. save_best_only: 如果 save_best_only=True,则仅在模型被认为是"最佳”时保存,并且根据监控的数量,最新的最佳模型将不会被覆盖.如果 filepath 不包含格式化选项如 {epoch},则 filepath 将被每个新的更好模型覆盖. mode: 其中一个 {"auto", "min", "max"}.如果 save_best_only=True,则根据监控数量的最大化或最小化决定是否覆盖当前保存文件.对于 val_acc,这应该是 "max",对于 val_loss,这应该是 "min",等等.在 "auto" 模式下,如果监控的数量是 "acc" 或以 "fmeasure" 开头,则模式设置为 "max",其余数量设置为 "min". save_weights_only: 如果 True,则仅保存模型的权重(model.save_weights(filepath)),否则保存整个模型(model.save(filepath)). save_freq: "epoch" 或整数.使用 "epoch" 时,回调在每个 epoch 后保存模型.使用整数时,回调在此许多批次后保存模型.如果 Model 编译时设置了 steps_per_execution=N,则保存条件将在每第 N 个批次检查.请注意,如果保存不与 epoch 对齐,监控的指标可能不太可靠(它可能仅反映 1 个批次,因为指标在每个 epoch 重置).默认为 "epoch". initial_value_threshold: 浮点数,监控指标的初始"最佳”值.仅在 save_best_value=True 时适用.仅在当前模型的性能优于此值时覆盖已保存的模型权重.