Keras 3 API 文档 / 回调API / 交换EMA权重

交换EMA权重

[source]

SwapEMAWeights class

keras.callbacks.SwapEMAWeights(swap_on_epoch=False)

在评估之前和之后交换模型权重和EMA权重.

此回调在模型评估之前用优化器的EMA权重(过去模型权重值的指数移动平均值,实现"Polyak平均”)替换模型权重值,并在评估之后恢复之前的权重.

SwapEMAWeights回调应与设置use_ema=True的优化器一起使用.

注意,为了节省内存,权重是就地交换的.如果你在其他回调中修改EMA权重或模型权重,行为是未定义的.

示例:

# 记得在优化器中设置`use_ema=True`
optimizer = SGD(use_ema=True)
model.compile(optimizer=optimizer, loss=..., metrics=...)

# 使用EMA权重计算指标
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])

# 如果你想用EMA权重保存模型检查点,可以设置`swap_on_epoch=True`并将ModelCheckpoint放在SwapEMAWeights之后.
model.fit(
    X_train,
    Y_train,
    callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)]
)

参数: swap_on_epoch:是否在on_epoch_begin()on_epoch_end()时执行交换.如果你想使用EMA权重进行其他回调(如ModelCheckpoint),这很有用.默认为False.