Keras 3 API 文档 / 回调API / 基础回调类

基础回调类

[source]

Callback class

keras.callbacks.Callback()

用于构建新回调的基础类.

回调可以传递给 fit()evaluate()predict() 等 keras 方法,以便在模型训练、评估和推理生命周期的各个阶段进行挂钩.

要创建自定义回调,请继承 keras.callbacks.Callback 并重写与感兴趣阶段相关的方法.

示例:

```python
>>> training_finished = False
>>> class MyCallback(Callback):
...   def on_train_end(self, logs=None):
...     global training_finished
...     training_finished = True
>>> model = Sequential([
...     layers.Dense(1, input_shape=(1,))])
>>> model.compile(loss='mean_squared_error')
>>> model.fit(np.array([[1.0]]), np.array([[1.0]]),
...           callbacks=[MyCallback()])
>>> assert training_finished == True
如果你想在自定义训练循环中使用 `Callback` 对象:

1. 你应该将所有回调打包成一个 `callbacks.CallbackList`,以便它们可以一起被调用.
2. 你需要在你的循环中手动调用所有 `on_*` 方法,如下所示:

示例:

```python
callbacks =  keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
    callbacks.on_epoch_begin(epoch)
    for i, data in dataset.enumerate():
    callbacks.on_train_batch_begin(i)
    batch_logs = model.train_step(data)
    callbacks.on_train_batch_end(i, batch_logs)
    epoch_logs = ...
    callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)

属性: params: 字典.训练参数(例如:冗长度、批量大小、epoch 数等). model: Model 实例.正在训练的模型的引用.

回调方法作为参数的 logs 字典将包含与当前批次或 epoch 相关的键(请参阅方法特定的文档字符串).