实时笔记本

您可以在 live session Binder 中运行此笔记本,或查看 Github 上的内容。

使用 PyTorch 进行批量预测

[ ]:
%matplotlib inline

本示例遵循 Torch 的 迁移学习教程。我们将

  1. 在特定任务(蚂蚁 vs. 蜜蜂)上微调预训练的卷积神经网络。

  2. 使用 Dask 集群对该模型进行批量预测。

注意: 此示例的依赖项在Binder环境中默认未安装。您需要执行

!conda install torchvision pytorch-cpu

在单元格中安装必要的包。

主要关注点是使用 Dask 集群进行批量预测。

下载数据

PyTorch 文档托管了一小部分数据。我们将在本地下载并提取它。

[ ]:
import urllib.request
import zipfile
[ ]:
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip")
zipfile.ZipFile(filename).extractall()

目录看起来像

hymenoptera_data/
    train/
        ants/
            0013035.jpg
            ...
            1030023514_aad5c608f9.jpg
        bees/
            1092977343_cb42b38d62.jpg
            ...
            2486729079_62df0920be.jpg

    train/
        ants/
            0013025.jpg
            ...
            1030023514_aad5c606d9.jpg
        bees/
            1092977343_cb42b38e62.jpg
            ...
            2486729079_62df0921be.jpg

按照 教程 ,我们将对模型进行微调。

[ ]:
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
                             dataloaders, class_names, finetune_model)

微调模型

我们的基础模型是 resnet18。它预测1,000个类别,而我们的模型只预测2个(蚂蚁或蜜蜂)。为了使该模型在examples.dask.org上快速训练,我们只使用几个epoch。

[ ]:
import dask
[ ]:
%%time
model = finetune_model()

在几张随机图片上看起来还不错:

[ ]:
visualize_model(model)

使用 Dask 进行批量预测

现在进入主要话题:在Dask集群上使用预训练模型进行批量预测。主要有两个复杂性,都与最小化数据移动量有关:

  1. 在工作者上加载数据。我们将使用 dask.delayed 在工作者的上加载数据,而不是在客户端上加载数据并将其发送到工作者。

  2. PyTorch 神经网络很大。 我们不希望它们出现在 Dask 任务图中,并且我们只希望它们移动一次。

[ ]:
from distributed import Client

client = Client(n_workers=2, threads_per_worker=2)
client

在工作者上加载数据

首先,我们将定义几个辅助函数来加载数据并为神经网络进行预处理。我们将在这里使用 dask.delayed,以便执行是惰性的并且在集群上进行。有关使用 dask.delayed 的更多信息,请参见 延迟示例

[ ]:
import glob
import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image


@dask.delayed
def load(path, fs=__builtins__):
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        return img


@dask.delayed
def transform(img):
    trn = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return trn(img)
[ ]:
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]

要从云存储加载数据,例如 Amazon S3,您可以使用

import s3fs

fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]

PyTorch 模型期望特定形状的张量,因此让我们对它们进行转换。

[ ]:
tensors = [transform(x) for x in objs]

模型期望输入批次,所以让我们将几个输入堆叠在一起。

[ ]:
batches = [dask.delayed(torch.stack)(batch)
           for batch in toolz.partition_all(10, tensors)]
batches[:5]

最后,我们将编写一个小的 predict 辅助函数来预测输出类别(0 或 1)。

[ ]:
@dask.delayed
def predict(batch, model):
    with torch.no_grad():
        out = model(batch)
        _, predicted = torch.max(out, 1)
        predicted = predicted.numpy()
    return predicted

移动模型

PyTorch 神经网络很大,所以我们不希望在任务图中多次重复它(每个批次一次)。

[ ]:
import pickle

dask.utils.format_bytes(len(pickle.dumps(model)))

相反,我们还会将模型本身包装在 dask.delayed 中。这意味着模型在 Dask 图中只会出现一次。

此外,由于我们在上面进行了微调(如果可用,这会在GPU上运行),我们应该将模型移回CPU。

[ ]:
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU

现在我们将使用(延迟的)``predict`` 方法来获取我们的预测。

[ ]:
predictions = [predict(batch, dmodel) for batch in batches]
dask.visualize(predictions[:2])

可视化有点混乱,但大型 PyTorch 模型是 predict 任务的共同祖先的那个盒子。

现在,我们可以进行计算,使用 Dask 集群来完成所有工作。由于我们处理的数据集较小,因此可以安全地使用 dask.compute 将结果带回到本地客户端。对于更大的数据集,您可能希望将其写入磁盘或云存储,或者继续在集群上处理预测。

[ ]:
predictions = dask.compute(*predictions)
predictions

摘要

这个例子展示了如何使用 PyTorch 和 Dask 对一组图像进行批量预测。我们小心地在集群上远程加载数据,并且只序列化一次大型神经网络。