Data loader

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on Colab

DataLoader

from __future__ import annotations
from fastai.torch_basics import *
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter,_DatasetKind
_loaders = (_MultiProcessingDataLoaderIter,_SingleProcessDataLoaderIter)
from nbdev.showdoc import *
bs = 4
letters = list(string.ascii_lowercase)

DataLoader helpers

fastai includes a replacement for PyTorch's DataLoader which is largely API-compatible, and adds lots of useful functionality and flexibility. Before we look at the class itself, there are a few helper functions we need to define.

def _wif(worker_id):
    set_num_threads(1)
    info = get_worker_info()
    ds = info.dataset.d
    ds.num_workers,ds.offs = info.num_workers,info.id
    set_seed(info.seed)
    ds.wif()

class _FakeLoader:
    def _fn_noops(self, x=None, *args, **kwargs): return x
    
    _IterableDataset_len_called,_auto_collation,collate_fn,drop_last = None,False,_fn_noops,False
    _index_sampler,generator,prefetch_factor,_get_shared_seed  = Inf.count,None,2,noop
    dataset_kind = _dataset_kind = _DatasetKind.Iterable
    
    def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers,pin_memory_device):
        self.dataset,self.default,self.worker_init_fn,self.pin_memory_device = self,d,_wif,pin_memory_device
        store_attr('d,pin_memory,num_workers,timeout,persistent_workers,pin_memory_device')

    def __iter__(self): return iter(self.d.create_batches(self.d.sample()))

    @property
    def multiprocessing_context(self): return (None,multiprocessing)[self.num_workers>0]

    @contextmanager
    def no_multiproc(self):
        old_num_workers = self.num_workers
        try:
            self.num_workers = 0
            yield self.d
        finally: self.num_workers = old_num_workers

_collate_types = (ndarray, Tensor, typing.Mapping, str)
def fa_collate(t):
    "A replacement for PyTorch `default_collate` which maintains types and handles `Sequence`s"
    b = t[0]
    return (default_collate(t) if isinstance(b, _collate_types)
            else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
            else default_collate(t))
# e.g. x is int, y is tuple
t = [(1,(2,3)),(1,(2,3))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])

t = [(1,(2,(3,4))),(1,(2,(3,4)))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])
test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])
def fa_convert(t):
    "A replacement for PyTorch `default_convert` which maintains types and handles `Sequence`s"
    return (default_convert(t) if isinstance(t, _collate_types)
            else type(t)([fa_convert(s) for s in t]) if isinstance(t, Sequence)
            else default_convert(t))
t0 = array([1,2])
t = [t0,(t0,t0)]

test_eq(fa_convert(t), default_convert(t))
test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])
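
The difference between the two: `fa_collate` stacks a list of individual samples into batch tensors, while `fa_convert` only converts the types inside an already-batched structure; as we'll see below, `create_batch` uses `fa_collate` when `bs` is set and `fa_convert` when the dataset is prebatched (`bs=None`). A minimal sketch of the contrast:

samples = [(array([1,2]), 0), (array([3,4]), 1)]
test_eq(fa_collate(samples)[0], tensor([[1,2],[3,4]]))  # samples stacked along a new batch dim
test_eq(fa_convert(samples[0])[0], tensor([1,2]))       # types converted, nothing stacked
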
class SkipItemException(Exception):
    "Raised to notify `DataLoader` to skip an item"
    pass
show_doc(SkipItemException, title_level=3)

SkipItemException

Raised to notify DataLoader to skip an item

def collate_error(e:Exception, batch):
    "Raises error when the batch could not collate, stating what items in the batch are different sizes and their types"
    err = f'Error when trying to collate the data into batches with fa_collate, at least two tensors in the batch are not the same size.\n\n'
    # We need to iterate through the entire batch and find where the mismatch is
    length = len(batch[0])
    for idx in range(length): # for each type in the batch
        for i, item in enumerate(batch):
            if i == 0: shape_a, type_a  = item[idx].shape, item[idx].__class__.__name__
            elif item[idx].shape != shape_a:
                shape_b = item[idx].shape
                if shape_a != shape_b:
                    err += f'Mismatch found on axis {idx} of the batch and is of type `{type_a}`:\n\tItem at index 0 has shape: {shape_a}\n\tItem at index {i} has shape: {shape_b}\n\nPlease include a transform in `after_item` that ensures all data of type {type_a} is the same size'
                    e.args = [err]
                    raise  # re-raise the original exception, now carrying the explanatory message
batch = [torch.rand(3, 375, 500), torch.rand(3, 375, 500), torch.rand(3, 500, 333)]
with ExceptionExpected(RuntimeError, "Mismatch found on axis 0 of the batch and is of type `Tensor`"):
    try:
        fa_collate(batch)
    except Exception as e:
        collate_error(e, batch)

DataLoader -

@funcs_kwargs
class DataLoader(GetAttr):
    _noop_methods = 'wif before_iter after_item before_batch after_batch after_iter'.split()
    for o in _noop_methods: exec(f"def {o}(self, x=None, *args, **kwargs): return x")
    _methods = _noop_methods + 'create_batches create_item create_batch retain \
        get_idxs sample shuffle_fn do_batch'.split()
    _default = 'dataset'
    def __init__(self, dataset=None, bs=None, num_workers=0, pin_memory=False, timeout=0, batch_size=None,
                 shuffle=False, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False,
                 pin_memory_device='', **kwargs):
        if batch_size is not None: bs = batch_size # PyTorch compatibility
        assert not (bs is None and drop_last)
        if indexed is None: indexed = (hasattr(dataset,'__getitem__')
                                       and not isinstance(dataset, IterableDataset))
        if not indexed and shuffle: raise ValueError("Can only shuffle an indexed dataset (not an iterable one).")
        if n is None:
            try: n = len(dataset)
            except TypeError: pass
        store_attr('dataset,bs,shuffle,drop_last,indexed,n,pin_memory,timeout,device')
        self.rng,self.num_workers,self.offs = random.Random(random.randint(0,2**32-1)),1,0
        if sys.platform == "win32" and IN_NOTEBOOK and num_workers > 0: num_workers = 0       
        if sys.platform == "darwin" and num_workers > 0: num_workers = 0       
        self.fake_l = _FakeLoader(self, pin_memory, num_workers, timeout, persistent_workers=persistent_workers,
                                  pin_memory_device=pin_memory_device)

    def __len__(self):
        if self.n is None: raise TypeError
        if self.bs is None: return self.n
        return self.n//self.bs + (0 if self.drop_last or self.n%self.bs==0 else 1)

    def get_idxs(self):
        idxs = Inf.count if self.indexed else Inf.nones
        if self.n is not None: idxs = list(itertools.islice(idxs, self.n))
        if self.shuffle: idxs = self.shuffle_fn(idxs)
        return idxs
    
    def sample(self): 
        return (b for i,b in enumerate(self.__idxs) if i//(self.bs or 1)%self.num_workers==self.offs)

    def __iter__(self):
        self.randomize()
        self.before_iter()
        self.__idxs=self.get_idxs() # called in context of main process (not workers/subprocesses)
        for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
            # pin_memory causes tuples to be converted to lists, so convert them back to tuples
            if self.pin_memory and type(b) == list: b = tuple(b)
            if self.device is not None: b = to_device(b, self.device)
            yield self.after_batch(b)
        self.after_iter()
        if hasattr(self, 'it'): del(self.it)

    def create_batches(self, samps):
        if self.dataset is not None: self.it = iter(self.dataset)
        res = filter(lambda o:o is not None, map(self.do_item, samps))
        yield from map(self.do_batch, self.chunkify(res))

    def new(self, dataset=None, cls=None, **kwargs):
        if dataset is None: dataset = self.dataset
        if cls is None: cls = type(self)
        cur_kwargs = dict(dataset=dataset, num_workers=self.fake_l.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
                          bs=self.bs, shuffle=self.shuffle, drop_last=self.drop_last, indexed=self.indexed, device=self.device)
        for n in self._methods:
            o = getattr(self, n)
            if not isinstance(o, MethodType): cur_kwargs[n] = o
        return cls(**merge(cur_kwargs, kwargs))

    @property
    def device(self) -> torch.device|None:
        return self._device

    @device.setter
    def device(self, device:int|str|torch.device|None):
        self._device, *_ = torch._C._nn._parse_to(device=device)
        if hasattr(self, 'after_batch') and hasattr(self.after_batch, 'fs'):
            for tfm in self.after_batch.fs:
                # Check that tfm.to is callable as TabularPandas & transforms set tfm.to as an object
                if hasattr(tfm, 'to') and callable(tfm.to): tfm.to(device)
                else:
                    for a in L(getattr(tfm, 'parameters', None)):
                        if hasattr(getattr(tfm, a), 'to'): setattr(tfm, a, getattr(tfm, a).to(device))

    @property
    def prebatched(self): return self.bs is None
    def do_item(self, s):
        try: return self.after_item(self.create_item(s))
        except SkipItemException: return None
    def chunkify(self, b): return b if self.prebatched else chunked(b, self.bs, self.drop_last)
    def shuffle_fn(self, idxs): return self.rng.sample(idxs, len(idxs))
    def randomize(self): self.rng = random.Random(self.rng.randint(0,2**32-1))
    def retain(self, res, b):  return retain_types(res, b[0] if is_listy(b) else b)
    def create_item(self, s):
        if self.indexed: return self.dataset[s or 0]
        elif s is None:  return next(self.it)
        else: raise IndexError("Cannot index an iterable dataset numerically - must use `None`.")
    def create_batch(self, b): 
        try: return (fa_collate,fa_convert)[self.prebatched](b)
        except Exception as e: 
            if not self.prebatched: collate_error(e,b)
            raise
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
    def to(self, device): self.device = device
    def one_batch(self):
        if self.n is not None and len(self)==0: raise ValueError(f'This DataLoader does not contain any batches')
        with self.fake_l.no_multiproc(): res = first(self)
        if hasattr(self, 'it'): delattr(self, 'it')
        return res
add_docs(DataLoader, "API compatible with PyTorch DataLoader, with a lot more callbacks and flexibility",
         get_idxs       = "Return a list of indices to reference the dataset. Calls `shuffle_fn` internally if `shuffle=True`.",
         sample         = "Same as `get_idxs` but returns a generator of indices to reference the dataset.",
         create_batches = "Takes output of `sample` as input, and returns batches of data. Does not apply `after_batch`.",
         new            = "Create a new `DataLoader` with given arguments keeping remaining arguments same as original `DataLoader`.",
         prebatched     = "Check if `bs` is None.",
         do_item        = "Combines `after_item` and `create_item` to get an item from dataset by providing index as input.",
         chunkify       = "Used by `create_batches` to turn generator of items (`b`) into batches.",
         shuffle_fn     = "Returns a random permutation of `idxs`.",
         randomize      = "Sets `DataLoader`'s random number generator state.",
         retain         = "Cast each item of `res` to type of matching item in `b` if it's a superclass.",
         create_item    = "Return the dataset entry at index `s` if the dataset is indexed, else the next item from the dataset iterator.",
         create_batch   = "Collate a list of items into a batch.",
         do_batch       = "Combines `create_batch` and `before_batch` to get a batch of items. Input is a list of items to collate.",
         to             = "Sets `self.device=device`.",
         one_batch      = "Return one batch from `DataLoader`.",
         wif            = "See pytorch `worker_init_fn` for details.", 
         before_iter    = "Called before `DataLoader` starts to read/iterate over the dataset.",
         after_item     = "Takes output of `create_item` as input and applies this function on it.",
         before_batch   = "It is called before collating a list of items into a batch. Input is a list of items.",
         after_batch    = "After collating mini-batch of items, the mini-batch is passed through this function.",
         after_iter     = "Called after `DataLoader` has fully read/iterated over the dataset.")
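
To illustrate two of the methods documented above, here is a quick sketch (toy values, made up for illustration): `new` derives a loader that keeps the original dataset and settings but overrides the given arguments, and `one_batch` grabs a single batch without multiprocessing.

dl  = DataLoader(list(range(12)), bs=4)
dl2 = dl.new(bs=6)                        # same dataset and settings, new batch size
test_eq(len(dl2), 2)
test_eq(dl2.one_batch(), tensor([0,1,2,3,4,5]))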

Arguments to `DataLoader` (a short usage sketch follows the list):

  • dataset: the dataset from which to load the data. Can be either a map-style or an iterable-style dataset.
  • bs (int): how many samples per batch to load (if `batch_size` is provided, `batch_size` overrides `bs`). If `bs=None`, it is assumed that `dataset.__getitem__` returns a batch.
  • num_workers (int): how many subprocesses to use for data loading. `0` means the data is loaded in the main process.
  • pin_memory (bool): if `True`, the data loader copies tensors into CUDA pinned memory before returning them.
  • timeout (float>0): the timeout value in seconds for collecting a batch from workers.
  • batch_size (int): only provided for PyTorch compatibility. Use `bs`.
  • shuffle (bool): if `True`, the data is shuffled every time the dataloader is fully read/iterated.
  • drop_last (bool): if `True`, the last incomplete batch is dropped.
  • indexed (bool): the `DataLoader` will guess whether the dataset can be indexed (or is iterable), but you can override that guess with this parameter. `True` by default.
  • n (int): defaults to `len(dataset)`. If you are using an iterable-style dataset, you can specify its size with `n`.
  • device (torch.device): defaults to `default_device()`, which is CUDA by default. You can specify the device as `torch.device('cpu')`.
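
A minimal sketch tying several of these arguments together (using the CPU device so it runs anywhere):

dl = DataLoader(list(range(8)), bs=3, shuffle=True, drop_last=True, device=torch.device('cpu'))
test_eq(len(dl), 2)                # 8//3 batches; the incomplete batch is dropped
for b in dl: test_eq(len(b), 3)    # with drop_last=True every batch is full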

Override `create_item` and use the default infinite sampler to get a stream of unknown length (`stop()` when you want to stop the stream).

class RandDL(DataLoader):
    def create_item(self, s):
        r = random.random()
        return r if r<0.95 else stop()

L(RandDL())
(#9) [0.09071201211613367,0.03249811556595483,0.6517029228593939,0.8584412116263038,0.759838440232556,0.3725873327679504,0.1445316323722865,0.18876233969606782,0.25518635091544917]
L(RandDL(bs=4, drop_last=True)).map(len)
(#1) [4]
dl = RandDL(bs=4, num_workers=4, drop_last=True)
L(dl).map(len)
(#1) [4]
test_num_workers = 0 if sys.platform in ("win32","darwin") else 4
test_eq(dl.fake_l.num_workers, test_num_workers)
with dl.fake_l.no_multiproc(): 
    test_eq(dl.fake_l.num_workers, 0)
    L(dl).map(len)
test_eq(dl.fake_l.num_workers, test_num_workers)
def _rand_item(s):
    r = random.random()
    return r if r<0.95 else stop()

L(DataLoader(create_item=_rand_item))
(#2) [0.624781366539204,0.39823513973618685]
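
`create_item` can also raise `SkipItemException` (defined above) to drop individual items. A minimal sketch (the `EvenDL` subclass is hypothetical, made up for illustration): items failing the test never reach a batch, because `do_item` catches the exception and `create_batches` filters out the resulting `None`s.

class EvenDL(DataLoader):
    def create_item(self, s):
        if s % 2 != 0: raise SkipItemException()  # skip odd items
        return s

test_eq(L(EvenDL(range(10), bs=2)), [tensor([0,2]),tensor([4,6]),tensor([8])])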

If you don't set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a batch.

ds1 = DataLoader(letters)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)

test_shuffled(L(DataLoader(letters, shuffle=True)), letters)

ds1 = DataLoader(letters, indexed=False)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)

t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = DataLoader(t2)
test_eq_type(L(ds2), t2)

t3 = L(array([0,1,2], dtype=np.int64),array([3,4,5], dtype=np.int64))
ds3 = DataLoader(t3)
test_eq_type(L(ds3), t3.map(tensor))

ds4 = DataLoader(t3, create_batch=noop, after_iter=lambda: setattr(t3, 'f', 1))
test_eq_type(L(ds4), t3)
test_eq(t3.f, 1)

If you do set `bs`, then `dataset` is assumed to provide an iterator or a `__getitem__` that returns a single item of a batch.

def twoepochs(d): return ' '.join(''.join(list(o)) for _ in range(2) for o in d)
ds1 = DataLoader(letters, bs=4, drop_last=True, num_workers=0)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

ds1 = DataLoader(letters,4,num_workers=2)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')

ds1 = DataLoader(range(12), bs=4, num_workers=3)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

ds1 = DataLoader([str(i) for i in range(11)], bs=4, after_iter=lambda: setattr(t3, 'f', 2))
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))
test_eq(t3.f, 2)

it = iter(DataLoader(map(noop,range(20)), bs=4, num_workers=1))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

Iterable dataloaders require specific tests.

class DummyIterableDataset(IterableDataset):
    def __iter__(self):
        yield from range(11)

ds1 = DataLoader(DummyIterableDataset(), bs=4)
# Check it works ok, and check we can do multiple passes
for i in range(3):
    test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10])))

# Check `drop_last` works fine (with multiple passes, since this prematurely terminates the iterator)
ds1 = DataLoader(DummyIterableDataset(), bs=4, drop_last=True)
for i in range(3):
    test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7])))
class SleepyDL(list):
    def __getitem__(self,i):
        time.sleep(random.random()/50)
        return super().__getitem__(i)

t = SleepyDL(letters)

%time test_eq(DataLoader(t, num_workers=0), letters)
%time test_eq(DataLoader(t, num_workers=2), letters)
%time test_eq(DataLoader(t, num_workers=4), letters)

dl = DataLoader(t, shuffle=True, num_workers=1)
test_shuffled(L(dl), letters)
test_shuffled(L(dl), L(dl))
L(dl)
CPU times: user 3.35 ms, sys: 890 µs, total: 4.24 ms
Wall time: 307 ms
CPU times: user 6.93 ms, sys: 860 µs, total: 7.79 ms
Wall time: 333 ms
CPU times: user 7.78 ms, sys: 722 µs, total: 8.51 ms
Wall time: 331 ms
(#26) ['l','h','f','r','z','s','u','x','m','p'...]
class SleepyQueue():
    "Simulate a queue with varying latency"
    def __init__(self, q): self.q=q
    def __iter__(self):
        while True:
            time.sleep(random.random()/100)
            try: yield self.q.get_nowait()
            except queues.Empty: return

q = Queue()
for o in range(30): q.put(o)
it = SleepyQueue(q)

if not (sys.platform == "win32" and IN_NOTEBOOK):
    %time test_shuffled(L(DataLoader(it, num_workers=4)), L(range(30)))
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
File <timed eval>:1

File ~/git/fastcore/fastcore/test.py:73, in test_shuffled(a, b)
     71 def test_shuffled(a,b):
     72     "`test` that `a` and `b` are shuffled versions of the same sequence of items"
---> 73     test_ne(a, b)
     74     test_eq(Counter(a), Counter(b))

File ~/git/fastcore/fastcore/test.py:49, in test_ne(a, b)
     47 def test_ne(a,b):
     48     "`test` that `a!=b`"
---> 49     test(a,b,nequals,'!=')

File ~/git/fastcore/fastcore/test.py:27, in test(a, b, cmp, cname)
     25 "`assert` that `cmp(a,b)`; display inputs and `cname or cmp.__name__` if it fails"
     26 if cname is None: cname=cmp.__name__
---> 27 assert cmp(a,b),f"{cname}:\n{a}\n{b}"

AssertionError: !=:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
class A(TensorBase): pass

for nw in (0,2):
    t = A(tensor([1,2]))
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    test_eq(type(b), A)

    t = (A(tensor([1,2])),)
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    test_eq(type(b[0]), A)
list(DataLoader(list(range(50)),bs=32,shuffle=True,num_workers=3))
[tensor([42, 12, 44, 21,  8,  6,  3, 37, 33,  9, 27, 34, 18, 26,  1, 23, 11, 41,
         15,  0, 49,  4, 38, 46, 48, 14, 40, 36, 17, 45, 30, 29]),
 tensor([19, 10, 22, 13, 25, 32, 35,  5,  2, 20, 47, 39, 16, 28, 43,  7, 31, 24])]
class A(TensorBase): pass
t = A(tensor(1,2))

tdl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=2, after_batch=to_device)
b = first(tdl)
test_eq(type(b), A)

# Unknown attributes are delegated to `dataset`
test_eq(tdl.pop(), tensor(1,2))

Override `get_idxs` to return the same index until consumption of the DL. This is intended to test consistent sampling behavior when `num_workers > 1`.

class AdamantDL(DataLoader):
    def get_idxs(self):
        r=random.randint(0,self.n-1)
        return [r] * self.n

test_eq(torch.cat(tuple(AdamantDL((list(range(50))),bs=16,num_workers=4))).unique().numel(),1)

Export -

from nbdev import nbdev_export
nbdev_export()
# from subprocess import Popen, PIPE
# # test that num_workers > 0 in scripts works when the Python process start method is spawn
# process = Popen(["python", "dltest.py"], stdout=PIPE)
# _, err = process.communicate(timeout=15)
# exit_code = process.wait()
# test_eq(exit_code, 0)