Keras 3 API 文档 / 回调API / 提前停止

提前停止

[source]

EarlyStopping class

keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=0,
    verbose=0,
    mode="auto",
    baseline=None,
    restore_best_weights=False,
    start_from_epoch=0,
)

停止训练当监控的指标不再改善时.

假设训练的目标是最小化损失.在这种情况下,监控的指标将是'loss',模式将是'min'.一个model.fit()训练循环将在每个 epoch 结束时检查损失是否不再减少,考虑到min_deltapatience(如果适用).一旦发现损失不再减少,model.stop_training将被标记为True,训练终止.

需要监控的量需要在logs字典中可用.为了实现这一点,在model.compile()时传递损失或指标.

参数: monitor: 要监控的量.默认为"val_loss". min_delta: 监控量中被视为改善的最小变化,即小于min_delta的绝对变化将不计为改善.默认为0. patience: 在训练停止之前没有改善的 epoch 数.默认为0. verbose: 冗长模式,0 或 1.模式 0 是静默的,模式 1 在回调采取行动时显示消息.默认为0. mode: 其中之一{"auto", "min", "max"}.在min模式下,训练将在监控量停止减少时停止;在"max"模式下,训练将在监控量停止增加时停止;在"auto"模式下,方向会根据监控量的名称自动推断.默认为"auto". baseline: 监控量的基线值.如果不是None,训练将在模型没有显示出超过基线的改善时停止.默认为None. restore_best_weights: 是否从监控量最佳值的 epoch 恢复模型权重.如果为False,则使用训练最后一步获得的模型权重.无论相对于baseline的表现如何,都会恢复 epoch.如果没有 epoch 改善baseline,训练将运行patience个 epoch 并恢复该组中最佳 epoch 的权重.默认为False. start_from_epoch: 开始监控改善之前的 epoch 数.这允许一个预热期,在此期间不期望改善,因此训练不会停止.默认为0.

示例:

>>> callback = keras.callbacks.EarlyStopping(monitor='loss',
...                                               patience=3)
>>> # 当损失连续三个 epoch 没有改善时,此回调将停止训练.
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
...                     epochs=10, batch_size=1, callbacks=[callback],
...                     verbose=0)
>>> len(history.history['loss'])  # 只运行了 4 个 epoch.
4