Training a simple neural network, with PyTorch data loading

Copyright 2018 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Let's combine everything we showed in the quickstart to train a simple neural network. We will first specify and train a simple multilayer perceptron (MLP) on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our models.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Hyperparameters

Let's get a few bookkeeping items out of the way.

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.key(0))
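
As a quick sanity check (a sketch, not part of the original notebook), we can print the initialized parameter shapes; each layer's weights come out as (n_out, n_in), ready to be applied on the left of a flattened image vector:

# `params` is a list of (weights, biases) pairs, one per layer
for i, (w, b) in enumerate(params):
  print(f"layer {i}: w {w.shape}, b {b.shape}")
# layer 0: w (512, 784), b (512,)
# layer 1: w (512, 512), b (512,)
# layer 2: w (10, 512), b (10,)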

Auto-batching predictions

Let us first define our prediction function. Note that we're defining this for a _single_ image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.

from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
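
Because `predict` returns `logits - logsumexp(logits)`, its outputs are log-probabilities. As a quick check (a sketch, not part of the original notebook), exponentiating them should sum to one:

# `predict` returns log-probabilities, so exponentiating should sum to 1
dummy_image = jnp.zeros(28 * 28)
print(jnp.exp(predict(params, dummy_image)).sum())  # ~1.0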

Let's check that our prediction function only works on single images.

# This works on single examples
random_flattened_image = random.normal(random.key(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.key(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
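
In `in_axes=(None, 0)`, the `None` says to share `params` across the batch, while `0` maps over the leading axis of the image batch. As a sketch (not part of the original notebook), we can check that the vmapped function agrees with a plain Python loop over examples:

# `vmap` should match mapping `predict` over the batch by hand
looped_preds = jnp.stack([predict(params, img) for img in random_flattened_images])
print(jnp.allclose(batched_preds, looped_preds, atol=1e-6))  # True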

At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.

Utility and loss functions

def one_hot(x, k, dtype=jnp.float32):
  """创建一个大小为 k 的 x 的独热编码。"""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
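
Before running the full training loop, here is a quick sketch (not part of the original notebook) that takes one `update` step on random data and checks that the loss decreases:

# One jit-compiled gradient step on random data; the loss should go down
dummy_x = random.normal(random.key(2), (batch_size, 28 * 28))
dummy_y = one_hot(jnp.arange(batch_size) % n_targets, n_targets)
print(loss(params, dummy_x, dummy_y))
new_params = update(params, dummy_x, dummy_y)
print(loss(new_params, dummy_x, dummy_y))  # typically smaller than before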

Data loading with PyTorch

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.

!pip install torch torchvision
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.7/site-packages (1.4.0)
Requirement already satisfied: torchvision in /opt/anaconda3/lib/python3.7/site-packages (0.5.0)
Requirement already satisfied: numpy in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.17.2)
Requirement already satisfied: six in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (1.12.0)
Requirement already satisfied: pillow>=4.1.1 in /opt/anaconda3/lib/python3.7/site-packages (from torchvision) (6.2.0)
import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super(self.__class__, self).__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))
# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw
Processing...
Done!
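
Subclassing `DataLoader` is one option; with recent PyTorch versions you can equally pass `collate_fn=numpy_collate` to a stock `DataLoader` (a sketch under that assumption):

# Equivalent to NumpyLoader: a plain DataLoader with the NumPy collate fn
training_generator_alt = data.DataLoader(
    mnist_dataset, batch_size=batch_size, collate_fn=numpy_collate)
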
# Get the full train dataset (for checking accuracy while training)
train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:55: UserWarning: train_data has been renamed data
  warnings.warn("train_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:45: UserWarning: train_labels has been renamed targets
  warnings.warn("train_labels has been renamed targets")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:60: UserWarning: test_data has been renamed data
  warnings.warn("test_data has been renamed data")
/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:50: UserWarning: test_labels has been renamed targets
  warnings.warn("test_labels has been renamed targets")

Training loop

import time

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in training_generator:
    y = one_hot(y, n_targets)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Epoch 0 in 55.15 sec
Training set accuracy 0.9157500267028809
Test set accuracy 0.9195000529289246
Epoch 1 in 42.26 sec
Training set accuracy 0.9372166991233826
Test set accuracy 0.9384000301361084
Epoch 2 in 44.37 sec
Training set accuracy 0.9491666555404663
Test set accuracy 0.9469000697135925
Epoch 3 in 41.75 sec
Training set accuracy 0.9568166732788086
Test set accuracy 0.9534000158309937
Epoch 4 in 41.16 sec
Training set accuracy 0.9631333351135254
Test set accuracy 0.9577000737190247
Epoch 5 in 38.89 sec
Training set accuracy 0.9675000309944153
Test set accuracy 0.9616000652313232
Epoch 6 in 40.68 sec
Training set accuracy 0.9708333611488342
Test set accuracy 0.9650000333786011
Epoch 7 in 41.50 sec
Training set accuracy 0.973716676235199
Test set accuracy 0.9672000408172607

We've now used the whole of the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization. We used NumPy to specify all of our computation, borrowed the great data loaders from PyTorch, and ran the whole thing on the GPU.