使用 PyTorch 进行批量预测
内容
实时笔记本
您可以在 live session 中运行此笔记本,或查看 Github 上的内容。
使用 PyTorch 进行批量预测¶
[ ]:
%matplotlib inline
本示例遵循 Torch 的 迁移学习教程。我们将
在特定任务(蚂蚁 vs. 蜜蜂)上微调预训练的卷积神经网络。
使用 Dask 集群对该模型进行批量预测。
注意: 此示例的依赖项在Binder环境中默认未安装。您需要执行
!conda install torchvision pytorch-cpu
在单元格中安装必要的包。
主要关注点是使用 Dask 集群进行批量预测。
下载数据¶
PyTorch 文档托管了一小部分数据。我们将在本地下载并提取它。
[ ]:
import urllib.request
import zipfile
[ ]:
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip")
zipfile.ZipFile(filename).extractall()
目录看起来像
hymenoptera_data/
train/
ants/
0013035.jpg
...
1030023514_aad5c608f9.jpg
bees/
1092977343_cb42b38d62.jpg
...
2486729079_62df0920be.jpg
train/
ants/
0013025.jpg
...
1030023514_aad5c606d9.jpg
bees/
1092977343_cb42b38e62.jpg
...
2486729079_62df0921be.jpg
按照 教程 ,我们将对模型进行微调。
[ ]:
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
dataloaders, class_names, finetune_model)
微调模型¶
我们的基础模型是 resnet18。它预测1,000个类别,而我们的模型只预测2个(蚂蚁或蜜蜂)。为了使该模型在examples.dask.org上快速训练,我们只使用几个epoch。
[ ]:
import dask
[ ]:
%%time
model = finetune_model()
在几张随机图片上看起来还不错:
[ ]:
visualize_model(model)
使用 Dask 进行批量预测¶
现在进入主要话题:在Dask集群上使用预训练模型进行批量预测。主要有两个复杂性,都与最小化数据移动量有关:
在工作者上加载数据。我们将使用
dask.delayed
在工作者的上加载数据,而不是在客户端上加载数据并将其发送到工作者。PyTorch 神经网络很大。 我们不希望它们出现在 Dask 任务图中,并且我们只希望它们移动一次。
[ ]:
from distributed import Client
client = Client(n_workers=2, threads_per_worker=2)
client
在工作者上加载数据¶
首先,我们将定义几个辅助函数来加载数据并为神经网络进行预处理。我们将在这里使用 dask.delayed
,以便执行是惰性的并且在集群上进行。有关使用 dask.delayed
的更多信息,请参见 延迟示例。
[ ]:
import glob
import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image
@dask.delayed
def load(path, fs=__builtins__):
with fs.open(path, 'rb') as f:
img = Image.open(f).convert("RGB")
return img
@dask.delayed
def transform(img):
trn = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return trn(img)
[ ]:
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]
要从云存储加载数据,例如 Amazon S3,您可以使用
import s3fs
fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]
PyTorch 模型期望特定形状的张量,因此让我们对它们进行转换。
[ ]:
tensors = [transform(x) for x in objs]
模型期望输入批次,所以让我们将几个输入堆叠在一起。
[ ]:
batches = [dask.delayed(torch.stack)(batch)
for batch in toolz.partition_all(10, tensors)]
batches[:5]
最后,我们将编写一个小的 predict
辅助函数来预测输出类别(0 或 1)。
[ ]:
@dask.delayed
def predict(batch, model):
with torch.no_grad():
out = model(batch)
_, predicted = torch.max(out, 1)
predicted = predicted.numpy()
return predicted
移动模型¶
PyTorch 神经网络很大,所以我们不希望在任务图中多次重复它(每个批次一次)。
[ ]:
import pickle
dask.utils.format_bytes(len(pickle.dumps(model)))
相反,我们还会将模型本身包装在 dask.delayed
中。这意味着模型在 Dask 图中只会出现一次。
此外,由于我们在上面进行了微调(如果可用,这会在GPU上运行),我们应该将模型移回CPU。
[ ]:
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU
现在我们将使用(延迟的)``predict`` 方法来获取我们的预测。
[ ]:
predictions = [predict(batch, dmodel) for batch in batches]
dask.visualize(predictions[:2])
可视化有点混乱,但大型 PyTorch 模型是 predict
任务的共同祖先的那个盒子。
现在,我们可以进行计算,使用 Dask 集群来完成所有工作。由于我们处理的数据集较小,因此可以安全地使用 dask.compute
将结果带回到本地客户端。对于更大的数据集,您可能希望将其写入磁盘或云存储,或者继续在集群上处理预测。
[ ]:
predictions = dask.compute(*predictions)
predictions
摘要¶
这个例子展示了如何使用 PyTorch 和 Dask 对一组图像进行批量预测。我们小心地在集群上远程加载数据,并且只序列化一次大型神经网络。