作者: Amogh Joshi
创建日期: 2021/06/02
最后修改: 2023/11/10
描述: 如何构建和训练一个卷积LSTM模型以进行下一帧视频预测。
卷积LSTM架构通过在LSTM层中引入卷积递归单元,将时间序列处理和计算机视觉结合在一起。在这个例子中,我们将探讨卷积LSTM模型在下一帧预测应用中的使用,即在给定一系列过去帧的情况下,预测视频帧接下来会是什么。
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import layers
import io
import imageio
from IPython.display import Image, display
from ipywidgets import widgets, Layout, HBox
在此示例中,我们将使用移动MNIST数据集。
我们将下载数据集,然后构建和预处理训练和验证集。
对于下一帧预测,我们的模型将使用前一帧,即f_n
,来预测新帧,称为f_(n + 1)
。为了使模型能够进行这些预测,我们需要处理数据,以便我们具有“移动”的输入和输出,其中输入数据是帧x_n
,用于预测帧y_(n + 1)
。
# 下载和加载数据集。
fpath = keras.utils.get_file(
"moving_mnist.npy",
"http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy",
)
dataset = np.load(fpath)
# 交换表示帧数和数据样本数的轴。
dataset = np.swapaxes(dataset, 0, 1)
# 我们将选择1000个例子并使用它们。
dataset = dataset[:1000, ...]
# 添加通道维度,因为图像是灰度的。
dataset = np.expand_dims(dataset, axis=-1)
# 使用索引分割为训练和验证集以优化内存。
indexes = np.arange(dataset.shape[0])
np.random.shuffle(indexes)
train_index = indexes[: int(0.9 * dataset.shape[0])]
val_index = indexes[int(0.9 * dataset.shape[0]) :]
train_dataset = dataset[train_index]
val_dataset = dataset[val_index]
# 将数据标准化到0-1范围。
train_dataset = train_dataset / 255
val_dataset = val_dataset / 255
# 我们将定义一个辅助函数来移动帧,其中
# `x`是帧0到n - 1,`y`是帧1到n。
def create_shifted_frames(data):
x = data[:, 0 : data.shape[1] - 1, :, :]
y = data[:, 1 : data.shape[1], :, :]
return x, y
# 将处理函数应用于数据集。
x_train, y_train = create_shifted_frames(train_dataset)
x_val, y_val = create_shifted_frames(val_dataset)
# 检查数据集。
print("训练数据集形状: " + str(x_train.shape) + ", " + str(y_train.shape))
print("验证数据集形状: " + str(x_val.shape) + ", " + str(y_val.shape))
下载数据从 http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
819200096/819200096 ━━━━━━━━━━━━━━━━━━━━ 116s 0us/step
训练数据集形状: (900, 19, 64, 64, 1), (900, 19, 64, 64, 1)
验证数据集形状: (100, 19, 64, 64, 1), (100, 19, 64, 64, 1)
我们的数据由帧序列组成,每帧用于预测即将到来的帧。让我们查看一些这些序列帧。
# 构建一个图形以可视化图像。
fig, axes = plt.subplots(4, 5, figsize=(10, 8))
# 绘制一个随机数据示例的每个序列图像。
data_choice = np.random.choice(range(len(train_dataset)), size=1)[0]
for idx, ax in enumerate(axes.flat):
ax.imshow(np.squeeze(train_dataset[data_choice][idx]), cmap="gray")
ax.set_title(f"帧 {idx + 1}")
ax.axis("off")
# 打印信息并显示图形。
print(f"显示示例 {data_choice} 的帧。")
plt.show()
显示示例 95 的帧。
要构建卷积LSTM模型,我们将使用ConvLSTM2D
层,该层将接受形状为(batch_size, num_frames, width, height, channels)
的输入,并返回同样形状的预测电影。
# Construct the input layer with no definite frame size.
inp = layers.Input(shape=(None, *x_train.shape[2:]))
# We will construct 3 `ConvLSTM2D` layers with batch normalization,
# followed by a `Conv3D` layer for the spatiotemporal outputs.
x = layers.ConvLSTM2D(
filters=64,
kernel_size=(5, 5),
padding="same",
return_sequences=True,
activation="relu",
)(inp)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
filters=64,
kernel_size=(3, 3),
padding="same",
return_sequences=True,
activation="relu",
)(x)
x = layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
filters=64,
kernel_size=(1, 1),
padding="same",
return_sequences=True,
activation="relu",
)(x)
x = layers.Conv3D(
filters=1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
)(x)
# Next, we will build the complete model and compile it.
model = keras.models.Model(inp, x)
model.compile(
loss=keras.losses.binary_crossentropy,
optimizer=keras.optimizers.Adam(),
)
在构建了我们的模型和数据之后,我们现在可以训练模型。
# 定义一些回调函数以改善训练。
early_stopping = keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor="val_loss", patience=5)
# 定义可修改的训练超参数。
epochs = 20
batch_size = 5
# 将模型拟合到训练数据上。
model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_val, y_val),
callbacks=[early_stopping, reduce_lr],
)
Epoch 1/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 50s 226ms/step - loss: 0.1510 - val_loss: 0.2966 - learning_rate: 0.0010
Epoch 2/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0287 - val_loss: 0.1766 - learning_rate: 0.0010
Epoch 3/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0269 - val_loss: 0.0661 - learning_rate: 0.0010
Epoch 4/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0264 - val_loss: 0.0279 - learning_rate: 0.0010
Epoch 5/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0258 - val_loss: 0.0254 - learning_rate: 0.0010
Epoch 6/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0256 - val_loss: 0.0253 - learning_rate: 0.0010
Epoch 7/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0248 - learning_rate: 0.0010
Epoch 8/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0251 - learning_rate: 0.0010
Epoch 9/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0247 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 10/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0246 - val_loss: 0.0246 - learning_rate: 0.0010
Epoch 11/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0245 - val_loss: 0.0247 - learning_rate: 0.0010
Epoch 12/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 13/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0244 - val_loss: 0.0245 - learning_rate: 0.0010
Epoch 14/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 15/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0243 - val_loss: 0.0241 - learning_rate: 0.0010
Epoch 16/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0242 - val_loss: 0.0242 - learning_rate: 0.0010
Epoch 17/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0240 - learning_rate: 0.0010
Epoch 18/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0243 - learning_rate: 0.0010
Epoch 19/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0244 - learning_rate: 0.0010
Epoch 20/20
180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0237 - val_loss: 0.0238 - learning_rate: 1.0000e-04
<keras.src.callbacks.history.History at 0x7ff294f9c340>
在我们现在构建并训练模型之后,我们可以基于新视频生成一些示例帧预测。
我们将从验证集随机选择一个示例,然后选择前十帧。从那里,我们可以让模型预测10个新帧,并将其与真实帧预测进行比较。
# 从验证数据集中选择一个随机示例。
example = val_dataset[np.random.choice(range(len(val_dataset)), size=1)[0]]
# 从示例中选择前十帧。
frames = example[:10, ...]
original_frames = example[10:, ...]
# 预测一组新的10帧。
for _ in range(10):
# 提取模型的预测并进行后处理。
new_prediction = model.predict(np.expand_dims(frames, axis=0))
new_prediction = np.squeeze(new_prediction, axis=0)
predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)
# 扩展预测帧的集合。
frames = np.concatenate((frames, predicted_frame), axis=0)
# 为原始帧和新帧构建一个图形。
fig, axes = plt.subplots(2, 10, figsize=(20, 4))
# 绘制原始帧。
for idx, ax in enumerate(axes[0]):
ax.imshow(np.squeeze(original_frames[idx]), cmap="gray")
ax.set_title(f"帧 {idx + 11}")
ax.axis("off")
# 绘制新帧。
new_frames = frames[10:, ...]
for idx, ax in enumerate(axes[1]):
ax.imshow(np.squeeze(new_frames[idx]), cmap="gray")
ax.set_title(f"帧 {idx + 11}")
ax.axis("off")
# 显示图形。
plt.show()
1/1 ━━━━━━━━━━━━━━━━━━━━ 2秒 2秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 800毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 805毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 790毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 821毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 824毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 928毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 813毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 810毫秒/步
1/1 ━━━━━━━━━━━━━━━━━━━━ 1秒 814毫秒/步
最后,我们将从验证集中挑选一些示例,并用它们构建一些 GIF,以查看模型的预测视频。
您可以使用托管在 Hugging Face Hub 上的训练模型,并在 Hugging Face Spaces 上尝试演示。
# 从数据集中选择一些随机示例。
examples = val_dataset[np.random.choice(range(len(val_dataset)), size=5)]
# 遍历示例并预测帧。
predicted_videos = []
for example in examples:
# 从示例中选择前/后十帧。
frames = example[:10, ...]
original_frames = example[10:, ...]
new_predictions = np.zeros(shape=(10, *frames[0].shape))
# 预测一组新的 10 帧。
for i in range(10):
# 提取模型的预测并进行后处理。
frames = example[: 10 + i + 1, ...]
new_prediction = model.predict(np.expand_dims(frames, axis=0))
new_prediction = np.squeeze(new_prediction, axis=0)
predicted_frame = np.expand_dims(new_prediction[-1, ...], axis=0)
# 扩展预测帧的集合。
new_predictions[i] = predicted_frame
# 为每一组真实值/预测图像创建和保存 GIF。
for frame_set in [original_frames, new_predictions]:
# 从选定的视频帧构建 GIF。
current_frames = np.squeeze(frame_set)
current_frames = current_frames[..., np.newaxis] * np.ones(3)
current_frames = (current_frames * 255).astype(np.uint8)
current_frames = list(current_frames)
# 从帧构建 GIF。
with io.BytesIO() as gif:
imageio.mimsave(gif, current_frames, "GIF", duration=200)
predicted_videos.append(gif.getvalue())
# 显示视频。
print(" 真实\t预测")
for i in range(0, len(predicted_videos), 2):
# 构建并显示一个 `HBox`,其中包含真实值和预测。
box = HBox(
[
widgets.Image(value=predicted_videos[i]),
widgets.Image(value=predicted_videos[i + 1]),
]
)
display(box)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step
真实预测
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\…
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xfb\xfb\xfb\xf4\…
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\…
HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xf9\xf9\xf9\xf7\…