版权 2018 JAX 作者。
根据 Apache 许可证,版本 2.0(“许可证”);
根据 Apache 许可证第 2.0 版(“许可证”)授权; 除非遵守许可证,否则您不得使用此文件。 您可以在以下网址获取许可证的副本:
https://www.apache.org/licenses/LICENSE-2.0
除非适用法律要求或书面同意,否则根据许可证分发的软件是在“原样”基础上分发的, 没有任何形式的保证或条件,无论是明示还是暗示。 有关具体语言的许可和限制,请参阅许可证。
使用tensorflow/datasets数据加载训练一个简单的神经网络#
从 neural_network_and_data_loading.ipynb
分叉而来
让我们结合在快速入门中展示的所有内容来训练一个简单的神经网络。我们将首先在MNIST上指定并训练一个简单的多层感知机(MLP),使用JAX进行计算。我们将使用tensorflow/datasets
数据加载API来加载图像和标签(因为它非常优秀,世界还不需要另一个数据加载库 :P)。
当然,您可以将JAX与任何与NumPy兼容的API结合使用,使模型的指定变得更加即插即用。这里仅出于说明目的,我们不会使用任何神经网络库或特殊API来构建我们的模型。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
超参数#
让我们先处理一些账务事项。
# 一个用于随机初始化权重和偏置的辅助函数
# 对于一个密集的神经网络层
def random_layer_params(m, n, key, scale=1e-2):
w_key, b_key = random.split(key)
return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
# 初始化一个全连接神经网络的所有层,其大小为 "sizes"
def init_network_params(sizes, key):
keys = random.split(key, len(sizes))
return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
自动批处理预测#
首先,我们定义我们的预测函数。请注意,我们是在为一个 单一 图像示例定义这个函数。我们将使用 JAX 的 vmap
函数来自动处理小批量数据,而不会带来性能损失。
from jax.scipy.special import logsumexp
def relu(x):
return jnp.maximum(0, x)
def predict(params, image):
# 逐样本预测
activations = image
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
让我们检查一下我们的预测函数是否仅适用于单张图像。
# 这适用于单个示例
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
preds = predict(params, random_flattened_images)
except TypeError:
print('Invalid shapes!')
Invalid shapes!
# 让我们将其升级,以使用 `vmap` 处理批量数据
# 创建一个批处理版本的 `predict` 函数
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` 的调用签名与 `predict` 相同。
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
在这一点上,我们已经具备了定义神经网络和训练所需的所有要素。我们构建了一个自动批处理版本的predict
,我们应该能够在损失函数中使用它。我们应该能够使用grad
来计算损失相对于神经网络参数的导数。最后,我们应该能够使用jit
来加速所有操作。
效用和损失函数#
def one_hot(x, k, dtype=jnp.float32):
"""创建一个大小为 k 的 x 的独热编码。"""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
return jnp.mean(predicted_class == target_class)
def loss(params, images, targets):
preds = batched_predict(params, images)
return -jnp.mean(preds * targets)
@jit
def update(params, x, y):
grads = grad(loss)(params, x, y)
return [(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)]
使用 tensorflow/datasets
加载数据#
JAX 专注于程序转换和基于加速器的 NumPy,因此我们不在 JAX 库中包含数据加载或预处理功能。已经有很多优秀的数据加载器可用,所以我们直接使用它们,而不是重新发明轮子。我们将使用 tensorflow/datasets
数据加载器。
import tensorflow as tf
# 确保TensorFlow无法识别GPU并占用所有GPU内存。
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'
# 获取完整数据集以进行评估
# tfds.load 返回 tf.Tensors(如果 batch_size 不等于 -1,则返回 tf.data.Datasets)
# 你可以使用 tfds.dataset_as_numpy 将它们转换为 NumPy 数组(或 NumPy 数组的迭代器)。
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c
# 全套列车组
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
# 完整测试集
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
训练循环#
import time
def get_train_batches():
# as_supervised=True 使我们得到的是一个 (图像, 标签) 的元组,而不是一个字典
ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
# 你可以构建一个任意的 tf.data 输入管道。
ds = ds.batch(batch_size).prefetch(1)
# tfds.dataset_as_numpy 将 tf.data.Dataset 转换为 NumPy 数组的迭代器
return tfds.as_numpy(ds)
for epoch in range(num_epochs):
start_time = time.time()
for x, y in get_train_batches():
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
epoch_time = time.time() - start_time
train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
Epoch 0 in 28.30 sec
Training set accuracy 0.8400499820709229
Test set accuracy 0.8469000458717346
Epoch 1 in 14.74 sec
Training set accuracy 0.8743667006492615
Test set accuracy 0.8803000450134277
Epoch 2 in 14.57 sec
Training set accuracy 0.8901500105857849
Test set accuracy 0.8957000374794006
Epoch 3 in 14.36 sec
Training set accuracy 0.8991333246231079
Test set accuracy 0.903700053691864
Epoch 4 in 14.20 sec
Training set accuracy 0.9061833620071411
Test set accuracy 0.9087000489234924
Epoch 5 in 14.89 sec
Training set accuracy 0.9113333225250244
Test set accuracy 0.912600040435791
Epoch 6 in 13.95 sec
Training set accuracy 0.9156833291053772
Test set accuracy 0.9176000356674194
Epoch 7 in 13.32 sec
Training set accuracy 0.9192000031471252
Test set accuracy 0.9214000701904297
Epoch 8 in 13.55 sec
Training set accuracy 0.9222500324249268
Test set accuracy 0.9241000413894653
Epoch 9 in 13.40 sec
Training set accuracy 0.9253666996955872
Test set accuracy 0.9269000291824341
我们现在已经使用了大部分的JAX API:grad
用于求导,jit
用于加速,以及 vmap
用于自动向量化。
我们使用NumPy来指定我们所有的计算,并借用了tensorflow/datasets
中的优秀数据加载器,并在GPU上运行了整个过程。