Callback
classkeras.callbacks.Callback()
用于构建新回调的基础类.
回调可以传递给 fit()
、evaluate()
和 predict()
等 keras 方法,以便在模型训练、评估和推理生命周期的各个阶段进行挂钩.
要创建自定义回调,请继承 keras.callbacks.Callback
并重写与感兴趣阶段相关的方法.
示例:
```python
>>> training_finished = False
>>> class MyCallback(Callback):
... def on_train_end(self, logs=None):
... global training_finished
... training_finished = True
>>> model = Sequential([
... layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(np.array([[1.0]]), np.array([[1.0]]),
... callbacks=[MyCallback()])
>>> assert training_finished == True
如果你想在自定义训练循环中使用 `Callback` 对象:
1. 你应该将所有回调打包成一个 `callbacks.CallbackList`,以便它们可以一起被调用.
2. 你需要在你的循环中手动调用所有 `on_*` 方法,如下所示:
示例:
```python
callbacks = keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
callbacks.on_epoch_begin(epoch)
for i, data in dataset.enumerate():
callbacks.on_train_batch_begin(i)
batch_logs = model.train_step(data)
callbacks.on_train_batch_end(i, batch_logs)
epoch_logs = ...
callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)
属性:
params: 字典.训练参数(例如:冗长度、批量大小、epoch 数等).
model: Model
实例.正在训练的模型的引用.
回调方法作为参数的 logs
字典将包含与当前批次或 epoch 相关的键(请参阅方法特定的文档字符串).