BackupAndRestore
classkeras.callbacks.BackupAndRestore(
backup_dir, save_freq="epoch", delete_checkpoint=True
)
备份和恢复训练状态的回调.
BackupAndRestore
回调旨在通过在每个 epoch 结束时备份训练状态到一个临时的检查点文件中,来从 Model.fit
执行过程中的中断中恢复训练.每个备份都会覆盖之前写入的检查点文件,因此在任何给定时间最多只有一个这样的检查点文件用于备份/恢复目的.
如果训练在完成之前重新启动,训练状态(包括 Model
权重和 epoch 编号)将在新的 Model.fit
运行开始时恢复到最近保存的状态.在 Model.fit
运行完成时,临时检查点文件将被删除.
请注意,用户有责任在中断后恢复作业.此回调对于故障容错目的的备份和恢复机制非常重要,并且期望从之前的检查点恢复的模型与用于备份的模型相同.如果用户更改传递给 compile 或 fit 的参数,为故障容错保存的检查点可能会失效.
示例:
```python
>>> class InterruptingCallback(keras.callbacks.Callback):
... def on_epoch_begin(self, epoch, logs=None):
... if epoch == 4:
... raise RuntimeError('Interrupting!')
>>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
>>> model = keras.models.Sequential([keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> try:
... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
... batch_size=1, callbacks=[callback, InterruptingCallback()],
... verbose=0)
... except:
... pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> # 只运行了 6 个更多的 epoch,因为第一次训练在零索引 epoch 4 时中断,第二次训练将从 4 到 9 继续.
>>> len(history.history['loss'])
>>> 6
```
参数:
- backup_dir: 字符串,存储恢复模型所需数据的目录路径.该目录不能在其他地方重复使用来存储其他文件,例如由另一个训练运行的 BackupAndRestore
回调,或同一训练运行的其他回调(例如 ModelCheckpoint
)使用.
- save_freq: "epoch"
、整数或 False
.当设置为 "epoch"
时,回调在每个 epoch 结束时保存检查点.当设置为整数时,回调每 save_freq
批次保存检查点.设置 save_freq=False
仅在使用抢占检查点(即 save_before_preemption=True
)时.
- delete_checkpoint: 布尔值.此 BackupAndRestore
回调通过保存检查点来备份训练状态.如果 delete_checkpoint=True
,训练完成后将删除检查点.如果希望保留检查点以供将来使用,请使用 False
.默认为 True
.