作者: Smerity 等
创建日期: 2015/08/17
最后修改: 2024/02/13
描述: 一个学习将数字字符串相加的模型,例如 "535+61" -> "596"。
在这个示例中,我们训练一个模型来学习将两个数字(以字符串形式提供)相加。
示例:
输入可以选择性地反向,这在许多任务中显示出提高性能,具体见: 学习执行 和 利用神经网络进行序列到序列学习。
从理论上讲,序列顺序反转为此问题引入了较短的短期依赖关系。
结果:
对于两位数(反向):
三位数(反向):
四位数(反向):
五位数(反向):
import keras
from keras import layers
import numpy as np
# 模型和数据集的参数。
TRAINING_SIZE = 50000
DIGITS = 3
REVERSE = True
# 输入的最大长度为 'int + int'(例如,'345+678')。整数的最大长度为
# DIGITS。
MAXLEN = DIGITS + 1 + DIGITS
class CharacterTable:
"""给定一组字符:
+ 将它们编码为单热编码的整数表示
+ 将单热编码或整数表示解码为其字符输出
+ 将概率向量解码为其字符输出
"""
def __init__(self, chars):
"""初始化字符表。
# 参数
chars: 输入中可能出现的字符。
"""
self.chars = sorted(set(chars))
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
def encode(self, C, num_rows):
"""对给定字符串 C 进行单热编码。
# 参数
C: 要编码的字符串。
num_rows: 返回的单热编码中的行数。用于保持每个数据的行数相同。
"""
x = np.zeros((num_rows, len(self.chars)))
for i, c in enumerate(C):
x[i, self.char_indices[c]] = 1
return x
def decode(self, x, calc_argmax=True):
"""将给定向量或二维数组解码为其字符输出。
# 参数
x: 概率或单热表示的向量或二维数组;或者字符索引的向量(用于 `calc_argmax=False`)。
calc_argmax: 是否寻找具有最大
概率的字符索引,默认值为 `True`。
"""
if calc_argmax:
x = x.argmax(axis=-1)
return "".join(self.indices_char[x] for x in x)
# 所有数字,加号和用于填充的空格。
chars = "0123456789+ "
ctable = CharacterTable(chars)
questions = []
expected = []
seen = set()
print("生成数据中...")
while len(questions) < TRAINING_SIZE:
f = lambda: int(
"".join(
np.random.choice(list("0123456789"))
for i in range(np.random.randint(1, DIGITS + 1))
)
)
a, b = f(), f()
# 跳过我们已经见过的加法问题
# 也跳过任何使得 x+Y == Y+x 的问题(因此进行排序)。
key = tuple(sorted((a, b)))
if key in seen:
continue
seen.add(key)
# 用空格填充数据,使其始终为 MAXLEN。
q = "{}+{}".format(a, b)
query = q + " " * (MAXLEN - len(q))
ans = str(a + b)
# 答案的最大大小可以是 DIGITS + 1。
ans += " " * (DIGITS + 1 - len(ans))
if REVERSE:
# 反转查询,例如:'12+345 ' 变为 ' 543+21'。(注意填充使用的空间。)
query = query[::-1]
questions.append(query)
expected.append(ans)
print("总问题数:", len(questions))
生成数据中...
总问题数: 50000
print("Vectorization...")
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
y[i] = ctable.encode(sentence, DIGITS + 1)
# Shuffle (x, y) in unison as the later parts of x will almost all be larger
# digits.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]
# Explicitly set apart 10% for validation data that we never train over.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]
print("Training Data:")
print(x_train.shape)
print(y_train.shape)
print("Validation Data:")
print(x_val.shape)
print(y_val.shape)
向量化...
训练数据:
(45000, 7, 12)
(45000, 4, 12)
验证数据:
(5000, 7, 12)
(5000, 4, 12)
print("构建模型...")
num_layers = 1 # 尝试添加更多的 LSTM 层!
model = keras.Sequential()
# 使用 LSTM "编码" 输入序列,产生大小为 128 的输出。
# 注意:在输入序列长度可变的情况下,
# 使用 input_shape=(None, num_feature)。
model.add(layers.Input((MAXLEN, len(chars))))
model.add(layers.LSTM(128))
# 将 RNN 的最后输出多次提供作为解码器 RNN 的输入,重复 'DIGITS + 1' 次,因为这是输出的最大长度,
# 例如,当 DIGITS=3 时,最大输出为 999+999=1998。
model.add(layers.RepeatVector(DIGITS + 1))
# 解码器 RNN 可以是多个层堆叠或单个层。
for _ in range(num_layers):
# 通过将 return_sequences 设置为 True,返回不仅是最后的输出,而是
# 迄今为止所有的输出,形式为 (num_samples, timesteps,
# output_dim)。这是必要的,因为下面的 TimeDistributed 需要
# 第一维是时间步。
model.add(layers.LSTM(128, return_sequences=True))
# 对于输入的每个时间切片应用一个密集层。对于输出序列的每个步骤,
# 决定应该选择哪个字符。
model.add(layers.Dense(len(chars), activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()
构建模型...
模型: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃ 层 (类型) ┃ 输出形状 ┃ 参数 # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ lstm (LSTM) │ (None, 128) │ 72,192 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ repeat_vector (RepeatVector) │ (None, 4, 128) │ 0 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ lstm_1 (LSTM) │ (None, 4, 128) │ 131,584 │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (Dense) │ (None, 4, 12) │ 1,548 │ └─────────────────────────────────┴───────────────────────────┴────────────┘
总参数: 205,324 (802.05 KB)
可训练参数: 205,324 (802.05 KB)
非可训练参数: 0 (0.00 B)
# 训练参数.
epochs = 30
batch_size = 32
# 结果显示的格式字符.
green_color = "\033[92m"
red_color = "\033[91m"
end_char = "\033[0m"
# 每代训练模型,并对验证数据集展示预测结果.
for epoch in range(1, epochs):
print()
print("迭代", epoch)
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=1,
validation_data=(x_val, y_val),
)
# 随机从验证集中选择10个样本,以便我们可以可视化错误.
for i in range(10):
ind = np.random.randint(0, len(x_val))
rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
preds = np.argmax(model.predict(rowx, verbose=0), axis=-1)
q = ctable.decode(rowx[0])
correct = ctable.decode(rowy[0])
guess = ctable.decode(preds[0], calc_argmax=False)
print("Q", q[::-1] if REVERSE else q, end=" ")
print("T", correct, end=" ")
if correct == guess:
print(f"{green_color}☑ {guess}{end_char}")
else:
print(f"{red_color}☒ {guess}{end_char}")
迭代 1
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 10s 6ms/step - accuracy: 0.3258 - loss: 1.8801 - val_accuracy: 0.4268 - val_loss: 1.5506
Q 499+58 T 557 ☒ 511
Q 51+638 T 689 ☒ 662
Q 87+12 T 99 ☒ 11
Q 259+55 T 314 ☒ 561
Q 704+87 T 791 ☒ 811
Q 988+67 T 1055 ☒ 101
Q 94+116 T 210 ☒ 111
Q 724+4 T 728 ☒ 777
Q 8+673 T 681 ☒ 772
Q 8+991 T 999 ☒ 900
迭代 2
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.4688 - loss: 1.4235 - val_accuracy: 0.5846 - val_loss: 1.1293
Q 379+6 T 385 ☒ 387
Q 15+504 T 519 ☒ 525
Q 552+299 T 851 ☒ 727
Q 664+0 T 664 ☒ 667
Q 500+257 T 757 ☒ 797
Q 50+818 T 868 ☒ 861
Q 310+691 T 1001 ☒ 900
Q 378+548 T 926 ☒ 827
Q 46+59 T 105 ☒ 122
Q 49+817 T 866 ☒ 871
迭代 3
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.6053 - loss: 1.0648 - val_accuracy: 0.6665 - val_loss: 0.9070
Q 1+266 T 267 ☒ 260
Q 73+257 T 330 ☒ 324
Q 421+628 T 1049 ☒ 1022
Q 85+590 T 675 ☒ 660
Q 66+34 T 100 ☒ 90
Q 256+639 T 895 ☒ 890
Q 6+677 T 683 ☑ 683
Q 162+637 T 799 ☒ 792
Q 5+324 T 329 ☒ 337
Q 848+34 T 882 ☒ 889
迭代 4
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - accuracy: 0.6781 - loss: 0.8751 - val_accuracy: 0.7037 - val_loss: 0.8092
Q 677+1 T 678 ☒ 676
Q 1+531 T 532 ☒ 535
Q 699+60 T 759 ☒ 756
Q 475+139 T 614 ☒ 616
Q 327+592 T 919 ☒ 915
Q 48+912 T 960 ☒ 956
Q 520+78 T 598 ☒ 505
Q 318+8 T 326 ☒ 327
Q 914+53 T 967 ☒ 966
Q 734+0 T 734 ☒ 733
迭代 5
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.7142 - loss: 0.7807 - val_accuracy: 0.7164 - val_loss: 0.7622
Q 150+337 T 487 ☒ 489
Q 72+934 T 1006 ☒ 1005
Q 171+62 T 233 ☒ 231
Q 108+21 T 129 ☒ 135
Q 755+896 T 1651 ☒ 1754
Q 117+1 T 118 ☒ 119
Q 148+95 T 243 ☒ 241
Q 719+956 T 1675 ☒ 1684
Q 656+43 T 699 ☒ 695
Q 368+8 T 376 ☒ 372
迭代 6
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - accuracy: 0.7377 - loss: 0.7157 - val_accuracy: 0.7541 - val_loss: 0.6684
Q 945+364 T 1309 ☒ 1305
Q 762+96 T 858 ☒ 855
Q 5+650 T 655 ☑ 655
Q 52+680 T 732 ☒ 735
Q 77+724 T 801 ☒ 800
Q 46+739 T 785 ☑ 785
Q 843+43 T 886 ☒ 885
Q 158+3 T 161 ☒ 160
Q 426+711 T 1137 ☒ 1138
Q 157+41 T 198 ☒ 190
迭代 7
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.7642 - loss: 0.6462 - val_accuracy: 0.7955 - val_loss: 0.5433
Q 822+27 T 849 ☑ 849
Q 82+495 T 577 ☒ 563
Q 9+366 T 375 ☒ 373
Q 9+598 T 607 ☒ 696
Q 186+41 T 227 ☒ 226
Q 920+920 T 1840 ☒ 1846
Q 445+345 T 790 ☒ 797
Q 783+588 T 1371 ☒ 1360
Q 36+473 T 509 ☒ 502
Q 354+61 T 415 ☒ 416
迭代 8
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - accuracy: 0.8326 - loss: 0.4626 - val_accuracy: 0.9069 - val_loss: 0.2744
Q 458+154 T 612 ☑ 612
Q 309+19 T 328 ☑ 328
Q 808+97 T 905 ☑ 905
Q 28+736 T 764 ☑ 764
Q 28+79 T 107 ☑ 107
Q 44+84 T 128 ☒ 129
Q 744+13 T 757 ☑ 757
Q 24+996 T 1020 ☒ 1011
Q 8+193 T 201 ☒ 101
Q 483+9 T 492 ☒ 491
第9次迭代
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/步 - 准确率: 0.9365 - 损失: 0.2275 - 验证准确率: 0.9657 - 验证损失: 0.1393
Q 330+61 T 391 ☑ 391
Q 207+82 T 289 ☒ 299
Q 23+234 T 257 ☑ 257
Q 690+567 T 1257 ☑ 1257
Q 293+97 T 390 ☒ 380
Q 312+868 T 1180 ☑ 1180
Q 956+40 T 996 ☑ 996
Q 97+105 T 202 ☒ 203
Q 365+44 T 409 ☑ 409
Q 76+639 T 715 ☑ 715
迭代 10
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 7s 5ms/step - 准确率: 0.9717 - 损失: 0.1223 - 验证集准确率: 0.9744 - 验证集损失: 0.0965
Q 123+143 T 266 ☑ 266
Q 599+1 T 600 ☑ 600
Q 729+237 T 966 ☑ 966
Q 51+120 T 171 ☑ 171
Q 97+672 T 769 ☑ 769
Q 840+5 T 845 ☑ 845
Q 86+494 T 580 ☒ 570
Q 278+51 T 329 ☑ 329
Q 8+832 T 840 ☑ 840
Q 383+9 T 392 ☑ 392
迭代 11
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 7s 5ms/step - 准确率: 0.9842 - 损失: 0.0729 - 验证集准确率: 0.9808 - 验证集损失: 0.0690
Q 181+923 T 1104 ☑ 1104
Q 747+24 T 771 ☑ 771
Q 6+65 T 71 ☑ 71
Q 75+994 T 1069 ☑ 1069
Q 712+587 T 1299 ☑ 1299
Q 977+10 T 987 ☑ 987
Q 742+24 T 766 ☑ 766
Q 215+44 T 259 ☑ 259
Q 817+683 T 1500 ☑ 1500
Q 102+48 T 150 ☒ 140
迭代 12
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - 准确率: 0.9820 - 损失: 0.0695 - 验证集准确率: 0.9823 - 验证集损失: 0.0596
Q 819+885 T 1704 ☒ 1604
Q 34+20 T 54 ☑ 54
Q 9+996 T 1005 ☑ 1005
Q 915+811 T 1726 ☑ 1726
Q 166+640 T 806 ☑ 806
Q 229+82 T 311 ☑ 311
Q 1+418 T 419 ☑ 419
Q 552+28 T 580 ☑ 580
Q 279+733 T 1012 ☑ 1012
Q 756+734 T 1490 ☑ 1490
迭代 13
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - 准确率: 0.9836 - 损失: 0.0587 - 验证集准确率: 0.9941 - 验证集损失: 0.0296
Q 793+0 T 793 ☑ 793
Q 79+48 T 127 ☑ 127
Q 484+92 T 576 ☑ 576
Q 39+655 T 694 ☑ 694
Q 64+708 T 772 ☑ 772
Q 568+341 T 909 ☑ 909
Q 9+918 T 927 ☑ 927
Q 48+912 T 960 ☑ 960
Q 31+289 T 320 ☑ 320
Q 378+548 T 926 ☑ 926
迭代 14
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - 准确率: 0.9915 - 损失: 0.0353 - 验证集准确率: 0.9901 - 验证集损失: 0.0358
Q 318+8 T 326 ☒ 325
Q 886+63 T 949 ☒ 959
Q 77+8 T 85 ☑ 85
Q 418+40 T 458 ☑ 458
Q 30+32 T 62 ☑ 62
Q 541+93 T 634 ☑ 634
Q 6+7 T 13 ☒ 14
Q 670+74 T 744 ☑ 744
Q 97+57 T 154 ☑ 154
Q 60+13 T 73 ☑ 73
迭代 15
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 6ms/step - 准确率: 0.9911 - 损失: 0.0335 - 验证集准确率: 0.9934 - 验证集损失: 0.0262
Q 24+533 T 557 ☑ 557
Q 324+44 T 368 ☑ 368
Q 63+505 T 568 ☑ 568
Q 670+74 T 744 ☑ 744
Q 58+359 T 417 ☑ 417
Q 16+428 T 444 ☑ 444
Q 17+99 T 116 ☑ 116
Q 779+903 T 1682 ☑ 1682
Q 40+576 T 616 ☑ 616
Q 947+773 T 1720 ☑ 1720
迭代 16
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - 准确率: 0.9968 - 损失: 0.0175 - 验证集准确率: 0.9901 - 验证集损失: 0.0360
Q 315+155 T 470 ☑ 470
Q 594+950 T 1544 ☑ 1544
Q 372+37 T 409 ☑ 409
Q 537+47 T 584 ☑ 584
Q 8+263 T 271 ☑ 271
Q 81+500 T 581 ☑ 581
Q 75+270 T 345 ☑ 345
Q 0+796 T 796 ☑ 796
Q 655+965 T 1620 ☑ 1620
Q 384+1 T 385 ☑ 385
迭代 17
1407/1407 ━━━━━━━━━━━━━━━━━━━━ 8s 5ms/step - 准确率: 0.9972 - 损失: 0.0148 - 验证集准确率: 0.9924 - 验证集损失: 0.0278
Q 168+83 T 251 ☑ 251
Q 951+53 T 1004 ☑ 1004
Q 400+37 T 437 ☑ 437
Q 996+473 T 1469 ☒ 1569
Q 996+847 T 1843 ☑ 1843
Q 842+550 T 1392 ☑ 1392
Q 479+72 T 551 ☑ 551
Q 753+782 T 1535 ☑ 1535
Q 99+188 T 287 ☑ 287
Q 2+974 T 976 ☑ 976
````
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
</div>
<div class="k-default-codeblock">
你将在 ~30 个周期后达到 99+% 的验证准确率。