作者: fchollet
创建日期: 2016/01/06
最后修改: 2023/11/20
描述: 自定义层创建的演示。
这个示例展示了如何创建自定义层,使用反整流器层(最初在 2016 年 1 月作为 Keras 示例脚本提出),它是 ReLU 的一种替代方案。它通过将输入的负部分和正部分分开,并返回两个部分的绝对值的连接来避免将负数部分归零。这避免了信息的丢失,代价则是维度的增加。为了修复维度的增加,我们将特征线性组合回原始大小的空间。
import keras
from keras import layers
from keras import ops
要实现自定义层:
__init__
或 build()
中的 add_weight()
创建状态变量。类似地,您也可以创建子层。call()
方法,接收层的输入张量并返回输出张量。get_config()
来启用序列化,它返回一个配置字典。另请参见指南 通过子类化制作新层和模型。
class Antirectifier(layers.Layer):
def __init__(self, initializer="he_normal", **kwargs):
super().__init__(**kwargs)
self.initializer = keras.initializers.get(initializer)
def build(self, input_shape):
output_dim = input_shape[-1]
self.kernel = self.add_weight(
shape=(output_dim * 2, output_dim),
initializer=self.initializer,
name="kernel",
trainable=True,
)
def call(self, inputs):
inputs -= ops.mean(inputs, axis=-1, keepdims=True)
pos = ops.relu(inputs)
neg = ops.relu(-inputs)
concatenated = ops.concatenate([pos, neg], axis=-1)
mixed = ops.matmul(concatenated, self.kernel)
return mixed
def get_config(self):
# 实现 get_config 以启用序列化。这是可选的。
base_config = super().get_config()
config = {"initializer": keras.initializers.serialize(self.initializer)}
return dict(list(base_config.items()) + list(config.items()))
# 训练参数
batch_size = 128
num_classes = 10
epochs = 20
# 数据,分为训练集和测试集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255
x_test /= 255
print(x_train.shape[0], "训练样本")
print(x_test.shape[0], "测试样本")
# 构建模型
model = keras.Sequential(
[
keras.Input(shape=(784,)),
layers.Dense(256),
Antirectifier(),
layers.Dense(256),
Antirectifier(),
layers.Dropout(0.5),
layers.Dense(10),
]
)
# 编译模型
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.RMSprop(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# 训练模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.15)
# 测试模型
model.evaluate(x_test, y_test)
60000 训练样本
10000 测试样本
第 1 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/步 - 损失: 0.6226 - 稀疏分类准确率: 0.8146 - val_loss: 0.4256 - val_sparse_categorical_accuracy: 0.8808
第 2 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1887 - 稀疏分类准确率: 0.9455 - val_loss: 0.1556 - val_sparse_categorical_accuracy: 0.9588
第 3 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1406 - 稀疏分类准确率: 0.9608 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9611
第 4 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.1084 - 稀疏分类准确率: 0.9691 - val_loss: 0.1178 - val_sparse_categorical_accuracy: 0.9731
第 5 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0995 - 稀疏分类准确率: 0.9738 - val_loss: 0.2207 - val_sparse_categorical_accuracy: 0.9526
第 6 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0831 - 稀疏分类准确率: 0.9769 - val_loss: 0.2092 - val_sparse_categorical_accuracy: 0.9533
第 7 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0736 - 稀疏分类准确率: 0.9807 - val_loss: 0.1129 - val_sparse_categorical_accuracy: 0.9749
第 8 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0653 - 稀疏分类准确率: 0.9827 - val_loss: 0.1000 - val_sparse_categorical_accuracy: 0.9791
第 9 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0590 - 稀疏分类准确率: 0.9833 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9750
第 10 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0587 - 稀疏分类准确率: 0.9854 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9747
第 11 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0622 - 稀疏分类准确率: 0.9853 - val_loss: 0.1473 - val_sparse_categorical_accuracy: 0.9753
第 12 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0554 - 稀疏分类准确率: 0.9869 - val_loss: 0.1529 - val_sparse_categorical_accuracy: 0.9757
第 13 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/步 - 损失: 0.0507 - 稀疏分类准确率: 0.9884 - val_loss: 0.1452 - val_sparse_categorical_accuracy: 0.9783
第 14 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0468 - 稀疏分类准确率: 0.9889 - val_loss: 0.1435 - val_sparse_categorical_accuracy: 0.9796
第 15 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9892 - val_loss: 0.1580 - val_sparse_categorical_accuracy: 0.9770
第 16 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0492 - 稀疏分类准确率: 0.9888 - val_loss: 0.1957 - val_sparse_categorical_accuracy: 0.9753
第 17 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9896 - val_loss: 0.1865 - val_sparse_categorical_accuracy: 0.9779
第 18 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0478 - 稀疏分类准确率: 0.9893 - val_loss: 0.2107 - val_sparse_categorical_accuracy: 0.9747
第 19 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0494 - 稀疏分类准确率: 0.9894 - val_loss: 0.2306 - val_sparse_categorical_accuracy: 0.9734
第 20 轮/20
399/399 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/步 - 损失: 0.0473 - 稀疏分类准确率: 0.9910 - val_loss: 0.2201 - val_sparse_categorical_accuracy: 0.9731
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 802us/步 - 损失: 0.2086 - 稀疏分类准确率: 0.9710
[0.19070196151733398, 0.9740999937057495]