代码示例 / 快速Keras食谱 / Endpoint layer pattern

Endpoint layer pattern

作者: fchollet
创建日期: 2019/05/10
最后修改: 2023/11/22
描述: 演示“端点层”模式(处理损失管理的层)。

在 Colab 中查看 GitHub 源代码


设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
import numpy as np

在功能 API 中使用端点层

“端点层”可以访问模型的目标,并在 call() 中使用 self.add_loss()Metric.update_state() 创建任意损失。 这使您能够定义与常规签名 fn(y_true, y_pred, sample_weight=None) 不匹配的损失和指标。

请注意,使用此模式,您可以为训练和评估设置单独的指标。

class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_metric = keras.metrics.BinaryAccuracy(name="accuracy")

    def call(self, logits, targets=None, sample_weight=None):
        if targets is not None:
            # 计算训练期间的损失值并将其添加
            # 到层中,使用 `self.add_loss()`。
            loss = self.loss_fn(targets, logits, sample_weight)
            self.add_loss(loss)

            # 将准确率记录为指标(我们可以记录任意指标,
            # 包括训练和推理的不同指标。)
            self.accuracy_metric.update_state(targets, logits, sample_weight)

        # 返回推理时的预测张量(用于 `.predict()`)。
        return tf.nn.softmax(logits)


inputs = keras.Input((764,), name="inputs")
logits = keras.layers.Dense(1)(inputs)
targets = keras.Input((1,), name="targets")
sample_weight = keras.Input((1,), name="sample_weight")
preds = LogisticEndpoint()(logits, targets, sample_weight)
model = keras.Model([inputs, targets, sample_weight], preds)

data = {
    "inputs": np.random.random((1000, 764)),
    "targets": np.random.random((1000, 1)),
    "sample_weight": np.random.random((1000, 1)),
}

model.compile(keras.optimizers.Adam(1e-3))
model.fit(data, epochs=2)
Epoch 1/2
 27/32 ━━━━━━━━━━━━━━━━━━━━  0s 2ms/step - loss: 0.3664   

WARNING: 所有在 absl::InitializeLog() 调用之前的日志消息都写入 STDERR
I0000 00:00:1700705222.380735 3351467 device_compiler.h:186] 使用 XLA 编译的集群!此行在进程的生命周期内最多记录一次。

 32/32 ━━━━━━━━━━━━━━━━━━━━ 2s 31ms/step - loss: 0.3663
Epoch 2/2
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.3627 

<keras.src.callbacks.history.History at 0x7f13401b1e10>

导出仅用于推理的模型

简单地不在模型中包含 targets。 权重保持不变。

inputs = keras.Input((764,), name="inputs")
logits = keras.layers.Dense(1)(inputs)
preds = LogisticEndpoint()(logits, targets=None, sample_weight=None)
inference_model = keras.Model(inputs, preds)

inference_model.set_weights(model.get_weights())

preds = inference_model.predict(np.random.random((1000, 764)))
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step

在子类模型中使用损失端点层

class LogReg(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)
        self.logistic_endpoint = LogisticEndpoint()

    def call(self, inputs):
        # 注意所有输入应该在第一个参数中
        # 因为我们希望能够调用 `model.fit(inputs)`。
        logits = self.dense(inputs["inputs"])
        preds = self.logistic_endpoint(
            logits=logits,
            targets=inputs["targets"],
            sample_weight=inputs["sample_weight"],
        )
        return preds


model = LogReg()
data = {
    "inputs": np.random.random((1000, 764)),
    "targets": np.random.random((1000, 1)),
    "sample_weight": np.random.random((1000, 1)),
}

model.compile(keras.optimizers.Adam(1e-3))
model.fit(data, epochs=2)
Epoch 1/2
 32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - loss: 0.3529
Epoch 2/2
 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.3509 

<keras.src.callbacks.history.History at 0x7f132c1d1450>