代码示例 / 快速Keras食谱 / 简单自定义层示例:反整流器

简单自定义层示例:反整流器

作者: fchollet
创建日期: 2016/01/06
最后修改: 2023/11/20
描述: 自定义层创建的演示。

在 Colab 中查看 GitHub 源码


介绍

这个示例展示了如何创建自定义层,使用反整流器层(最初在 2016 年 1 月作为 Keras 示例脚本提出),它是 ReLU 的一种替代方案。它通过将输入的负部分和正部分分开,并返回两个部分的绝对值的连接来避免将负数部分归零。这避免了信息的丢失,代价则是维度的增加。为了修复维度的增加,我们将特征线性组合回原始大小的空间。


设置

import keras
from keras import layers
from keras import ops

反整流器层

要实现自定义层:

  • 通过 __init__build() 中的 add_weight() 创建状态变量。类似地,您也可以创建子层。
  • 实现 call() 方法,接收层的输入张量并返回输出张量。
  • 可选地,您也可以通过实现 get_config() 来启用序列化,它返回一个配置字典。

另请参见指南 通过子类化制作新层和模型

class Antirectifier(layers.Layer):
    def __init__(self, initializer="he_normal", **kwargs):
        super().__init__(**kwargs)
        self.initializer = keras.initializers.get(initializer)

    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer=self.initializer,
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        inputs -= ops.mean(inputs, axis=-1, keepdims=True)
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        concatenated = ops.concatenate([pos, neg], axis=-1)
        mixed = ops.matmul(concatenated, self.kernel)
        return mixed

    def get_config(self):
        # 实现 get_config 以启用序列化。这是可选的。
        base_config = super().get_config()
        config = {"initializer": keras.initializers.serialize(self.initializer)}
        return dict(list(base_config.items()) + list(config.items()))

让我们在 MNIST 上测试一下

# 训练参数
batch_size = 128
num_classes = 10
epochs = 20

# 数据,分为训练集和测试集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print(x_train.shape[0], "训练样本")
print(x_test.shape[0], "测试样本")

# 构建模型
model = keras.Sequential(
    [
        keras.Input(shape=(784,)),
        layers.Dense(256),
        Antirectifier(),
        layers.Dense(256),
        Antirectifier(),
        layers.Dropout(0.5),
        layers.Dense(10),
    ]
)

# 编译模型
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# 训练模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)

# 测试模型
model.evaluate(x_test, y_test)
60000 训练样本
10000 测试样本
第 1 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/步 - 损失: 0.6226 - 稀疏分类准确率: 0.8146 - val_loss: 0.4256 - val_sparse_categorical_accuracy: 0.8808
第 2 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1887 - 稀疏分类准确率: 0.9455 - val_loss: 0.1556 - val_sparse_categorical_accuracy: 0.9588
第 3 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1406 - 稀疏分类准确率: 0.9608 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9611
第 4 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1084 - 稀疏分类准确率: 0.9691 - val_loss: 0.1178 - val_sparse_categorical_accuracy: 0.9731
第 5 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0995 - 稀疏分类准确率: 0.9738 - val_loss: 0.2207 - val_sparse_categorical_accuracy: 0.9526
第 6 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0831 - 稀疏分类准确率: 0.9769 - val_loss: 0.2092 - val_sparse_categorical_accuracy: 0.9533
第 7 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0736 - 稀疏分类准确率: 0.9807 - val_loss: 0.1129 - val_sparse_categorical_accuracy: 0.9749
第 8 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0653 - 稀疏分类准确率: 0.9827 - val_loss: 0.1000 - val_sparse_categorical_accuracy: 0.9791
第 9 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0590 - 稀疏分类准确率: 0.9833 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9750
第 10 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0587 - 稀疏分类准确率: 0.9854 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9747
第 11 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0622 - 稀疏分类准确率: 0.9853 - val_loss: 0.1473 - val_sparse_categorical_accuracy: 0.9753
第 12 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0554 - 稀疏分类准确率: 0.9869 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9757
第 13 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/步 - 损失: 0.0507 - 稀疏分类准确率: 0.9884 - val_loss: 0.1452 - val_sparse_categorical_accuracy: 0.9783
第 14 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0468 - 稀疏分类准确率: 0.9889 - val_loss: 0.1435 - val_sparse_categorical_accuracy: 0.9796
第 15 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9892 - val_loss: 0.1580 - val_sparse_categorical_accuracy: 0.9770
第 16 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0492 - 稀疏分类准确率: 0.9888 - val_loss: 0.1957 - val_sparse_categorical_accuracy: 0.9753
第 17 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9896 - val_loss: 0.1865 - val_sparse_categorical_accuracy: 0.9779
第 18 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9893 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9747
第 19 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0494 - 稀疏分类准确率: 0.9894 - val_loss: 0.2306 - val_sparse_categorical_accuracy: 0.9734
第 20 轮/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0473 - 稀疏分类准确率: 0.9910 - val_loss: 0.2201 - val_sparse_categorical_accuracy: 0.9731
 313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 802us/步 - 损失: 0.2086 - 稀疏分类准确率: 0.9710

[0.19070196151733398, 0.9740999937057495]