Keras 3 API 文档 / 回调API / Lambda回调

Lambda回调

[source]

LambdaCallback class

keras.callbacks.LambdaCallback(
    on_epoch_begin=None,
    on_epoch_end=None,
    on_train_begin=None,
    on_train_end=None,
    on_train_batch_begin=None,
    on_train_batch_end=None,
    **kwargs
)

用于在运行时创建简单、自定义回调的回调函数.

该回调函数通过匿名函数构造,这些函数将在适当的时间被调用(在 `Model.{fit | evaluate | predict}` 期间).请注意,回调函数期望位置参数,如下所示:

- `on_epoch_begin` 和 `on_epoch_end` 期望两个位置参数:
  `epoch`, `logs`
- `on_train_begin` 和 `on_train_end` 期望一个位置参数:
  `logs`
- `on_train_batch_begin` 和 `on_train_batch_end` 期望两个位置
  参数:`batch`, `logs`
- 有关函数及其预期参数的完整列表,请参见 `Callback` 类定义.

参数:
    on_epoch_begin: 在每个 epoch 开始时调用.
    on_epoch_end: 在每个 epoch 结束时调用.
    on_train_begin: 在模型训练开始时调用.
    on_train_end: 在模型训练结束时调用.
    on_train_batch_begin: 在每个训练批次开始时调用.
    on_train_batch_end: 在每个训练批次结束时调用.
    kwargs: 您希望通过传递 `function_name=function` 来覆盖的 `Callback` 中的任何函数.例如,
        `LambdaCallback(.., on_train_end=train_end_fn)`.自定义函数
        需要具有与 `Callback` 中定义的相同的参数.

示例:

```python
# 在每个批次开始时打印批次编号.
batch_print_callback = LambdaCallback(
    on_train_batch_begin=lambda batch,logs: print(batch))

# 以 JSON 格式将 epoch 损失流式传输到文件中.文件内容
# 不是格式良好的 JSON,而是每行一个 JSON 对象.
import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '

'), on_train_end=lambda logs: json_log.close() )

# 在模型训练完成后终止某些进程.
processes = ...
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])

model.fit(...,
          callbacks=[batch_print_callback,
                     json_logging_callback,
                     cleanup_callback])
```