Data core

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on Colab
from __future__ import annotations
from fastai.torch_basics import *
from fastai.data.load import *
from nbdev.cli import *
from nbdev.showdoc import *

Core functionality for gathering data

The classes here provide functionality for applying a list of transforms to a set of items (`TfmdLists`, `Datasets`) or a `DataLoader` (`TfmdDL`), as well as the base class used to gather the data for model training: `DataLoaders`.

TfmdDL -

@typedispatch
def show_batch(
    x, # Input(s) in the batch
    y, # Target(s) in the batch
    samples, # List of (`x`, `y`) pairs of length `max_n`
    ctxs=None, # List of `ctx` objects to show data. Could be a matplotlib axis, DataFrame, etc.
    max_n=9, # Maximum number of `samples` to show
    **kwargs
):
    "Show `max_n` input(s) and target(s) from the batch."
    if ctxs is None: ctxs = Inf.nones
    if hasattr(samples[0], 'show'):
        ctxs = [s.show(ctx=c, **kwargs) for s,c,_ in zip(samples,ctxs,range(max_n))]
    else:
        for i in range_of(samples[0]):
            ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
    return ctxs

`show_batch` is a type-dispatched function that is responsible for showing decoded `samples`. `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. There is a different implementation of `show_batch` if `x` is a `TensorImage` or a `TensorText`, for instance (see `vision.core` or `text.data` for details). `ctxs` can be passed, but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation.
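
For instance, a specialized implementation can be registered by annotating the type of `x`; this is how `vision.data` and `text.data` hook in their own versions. A minimal sketch, using a hypothetical `MySignal` type:

class MySignal(TensorBase): pass

@typedispatch
def show_batch(x:MySignal, y, samples, ctxs=None, max_n=9, **kwargs):
    # Hypothetical display for `MySignal` inputs: just print each decoded sample
    for s,_ in zip(samples, range(max_n)): print(s)
    return ctxs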

@typedispatch
def show_results(
    x, # Input(s) in the batch
    y, # Target(s) in the batch
    samples, # List of (`x`, `y`) pairs of length `max_n`
    outs, # List of predicted output(s) from the model
    ctxs=None, # List of `ctx` objects to show data. Could be a matplotlib axis, DataFrame, etc.
    max_n=9, # Maximum number of `samples` to show
    **kwargs
):
    "Show `max_n` results with input(s), target(s) and prediction(s)."
    if ctxs is None: ctxs = Inf.nones
    for i in range(len(samples[0])):
        ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
    for i in range(len(outs[0])):
        ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(i),ctxs,range(max_n))]
    return ctxs

`show_results` is a type-dispatched function that is responsible for showing decoded `samples` and their corresponding `outs`. Like in `show_batch`, `x` and `y` are the input and the target in the batch to be shown, and are passed along to dispatch on their types. `ctxs` can be passed, but the function is responsible for creating them if necessary. `kwargs` depend on the specific implementation.

_all_ = ["show_batch", "show_results"]
_batch_tfms = ('after_item','before_batch','after_batch')
class TfmdDL(DataLoader):
    "Transformed `DataLoader`"
    @delegates(DataLoader.__init__)
    def __init__(self,
        dataset, # Map- or iterable-style dataset from which to load the data
        bs:int=64, # Size of batch
        shuffle:bool=False, # Whether to shuffle data
        num_workers:int=None, # Number of CPU cores to use in parallel (default: all available, up to 16)
        verbose:bool=False, # Whether to print verbose logs
        do_setup:bool=True, # Whether to run `setup()` for batch transform(s)
        **kwargs
    ):
        if num_workers is None: num_workers = min(16, defaults.cpus)
        for nm in _batch_tfms: kwargs[nm] = Pipeline(kwargs.get(nm,None))
        super().__init__(dataset, bs=bs, shuffle=shuffle, num_workers=num_workers, **kwargs)
        if do_setup:
            for nm in _batch_tfms:
                pv(f"Setting up {nm}: {kwargs[nm]}", verbose)
                kwargs[nm].setup(self)

    def _one_pass(self):
        b = self.do_batch([self.do_item(None)])
        if self.device is not None: b = to_device(b, self.device)
        its = self.after_batch(b)
        self._n_inp = 1 if not isinstance(its, (list,tuple)) or len(its)==1 else len(its)-1
        self._types = explode_types(its)

    def _retain_dl(self,b):
        if not getattr(self, '_types', None): self._one_pass()
        return retain_types(b, typs=self._types)

    @delegates(DataLoader.new)
    def new(self, 
        dataset=None, # Map- or iterable-style dataset from which to load the data
        cls=None, # Class of the newly created `DataLoader` object
        **kwargs
    ):
        res = super().new(dataset, cls, do_setup=False, **kwargs)
        if not hasattr(self, '_n_inp') or not hasattr(self, '_types'):
            try:
                self._one_pass()
                res._n_inp,res._types = self._n_inp,self._types
            except Exception as e: 
                print("Could not do one pass in your dataloader, there is something wrong in it. Please see the stack trace below:")
                raise
        else: res._n_inp,res._types = self._n_inp,self._types
        return res

    def before_iter(self):
        super().before_iter()
        split_idx = getattr(self.dataset, 'split_idx', None)
        for nm in _batch_tfms:
            f = getattr(self,nm)
            if isinstance(f,Pipeline): f.split_idx=split_idx

    def decode(self, 
        b # Batch to decode
    ):
        return to_cpu(self.after_batch.decode(self._retain_dl(b)))
    def decode_batch(self, 
        b, # Batch to decode
        max_n:int=9, # Maximum number of items to decode
        full:bool=True # Whether to decode all transforms. If `False`, decode up to the point the item knows how to show itself
    ): 
        return self._decode_batch(self.decode(b), max_n, full)

    def _decode_batch(self, b, max_n=9, full=True):
        f = self.after_item.decode
        f1 = self.before_batch.decode
        f = compose(f1, f, partial(getcallable(self.dataset,'decode'), full = full))
        return L(batch_to_samples(b, max_n=max_n)).map(f)

    def _pre_show_batch(self, b, max_n=9):
        "Decode `b` to be ready for `show_batch`"
        b = self.decode(b)
        if hasattr(b, 'show'): return b,None,None
        its = self._decode_batch(b, max_n, full=False)
        if not is_listy(b): b,its = [b],L((o,) for o in its)
        return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its

    def show_batch(self,
        b=None, # Batch to show
        max_n:int=9, # Maximum number of items to show
        ctxs=None, # List of `ctx` objects to show data. Could be a matplotlib axis, DataFrame, etc.
        show:bool=True, # Whether to display data
        unique:bool=False, # Whether to show only one
        **kwargs
    ):
        "Show `max_n` input(s) and target(s) from the batch."
        if unique:
            old_get_idxs = self.get_idxs
            self.get_idxs = lambda: Inf.zeros
        if b is None: b = self.one_batch()
        if not show: return self._pre_show_batch(b, max_n=max_n)
        show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)
        if unique: self.get_idxs = old_get_idxs

    def show_results(self, 
        b, # Batch to show results for
        out, # Predicted output from model for the batch
        max_n:int=9, # Maximum number of items to show
        ctxs=None, # List of `ctx` objects to show data. Could be a matplotlib axis, DataFrame, etc.
        show:bool=True, # Whether to display data
        **kwargs
    ):
        "Show `max_n` results with input(s), target(s) and prediction(s)."
        x,y,its = self.show_batch(b, max_n=max_n, show=False)
        b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))
        x1,y1,outs = self.show_batch(b_out, max_n=max_n, show=False)
        res = (x,x1,None,None) if its is None else (x, y, its, outs.itemgot(slice(self.n_inp,None)))
        if not show: return res
        show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)

    def to(self, device):
        # `to` is referenced by `add_docs` and used as `tdl.to(...)` below: it moves any
        # tensor state declared in a transform's `parameters` along with the DataLoader.
        self.device = device
        for tfm in self.after_batch.fs:
            for a in L(getattr(tfm, 'parameters', None)): setattr(tfm, a, getattr(tfm, a).to(device))
        return self

    @property
    def n_inp(self) -> int:
        "Number of elements in `Datasets` or `TfmdDL` tuple to be considered part of input."
        if hasattr(self.dataset, 'n_inp'): return self.dataset.n_inp
        if not hasattr(self, '_n_inp'): self._one_pass()
        return self._n_inp

A `TfmdDL` is a `DataLoader` that creates a `Pipeline` from a list of `Transform`s for the callbacks `after_item`, `before_batch` and `after_batch`. As a result, it can decode or show a processed `batch`.
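
A minimal sketch of the first and last of those events (plain callables are wrapped in `Transform`s by the `Pipeline`; `before_batch` works the same way but runs on the list of items before collation):

dl = TfmdDL(list(range(8)),
            after_item=lambda o: o+1,   # runs on each individual item
            after_batch=lambda b: -b,   # runs on the collated batch
            bs=4)
test_eq(first(dl), tensor([-1,-2,-3,-4]))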

add_docs(TfmdDL,
         decode="Decode `b` using `tfms`",
         decode_batch="Decode `b` entirely",
         new="Create a new version of self with a few changed attributes",
         show_batch="Show `b` (defaults to `one_batch`), a list of lists of pipeline outputs (i.e. output of a `DataLoader`)",
         show_results="Show each item of `b` and `out`",
         before_iter="override",
         to="Put self and its transforms state on `device`")
class _Category(int, ShowTitle): pass
# Test retain type
class NegTfm(Transform):
    def encodes(self, x): return torch.neg(x)
    def decodes(self, x): return torch.neg(x)
    
tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)
b = tdl.one_batch()
test_eq(type(b[0]), TensorImage)
b = (tensor([1.,1.,1.,1.]),)
test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)
class A(Transform): 
    def encodes(self, x): return x 
    def decodes(self, x): return TitledInt(x) 

@Transform
def f(x)->None: return fastuple((x,x))

start = torch.arange(50)
test_eq_type(f(2), fastuple((2,2)))
a = A()
tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)
x,y = tdl.one_batch()
test_eq(type(y), fastuple)

s = tdl.decode_batch((x,y))
test_eq(type(s[0][1]), fastuple)
tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
test_eq(len(tdl), (50-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
test_stdout(partial(tdl.show_batch, unique=True), '0\n0\n0\n0')
class B(Transform):
    parameters = 'a'
    def __init__(self): self.a = torch.tensor(0.)
    def encodes(self, x): return x
    
tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)
test_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))
tdl.to(default_device())
test_eq(tdl.after_batch.fs[0].a.device, default_device())

Methods

show_doc(TfmdDL.one_batch)

DataLoader.one_batch[source]

DataLoader.one_batch()

Return one batch from DataLoader.

tfm = NegTfm()
tdl = TfmdDL(start, after_batch=tfm, bs=4)
b = tdl.one_batch()
test_eq(tensor([0,-1,-2,-3]), b)
show_doc(TfmdDL.decode)

TfmdDL.decode[source]

TfmdDL.decode(b)

Decode b using tfms

test_eq(tdl.decode(b), tensor(0,1,2,3))
show_doc(TfmdDL.decode_batch)

TfmdDL.decode_batch[source]

TfmdDL.decode_batch(b, max_n=9, full=True)

Decode b entirely

test_eq(tdl.decode_batch(b), [0,1,2,3])
show_doc(TfmdDL.show_batch)

TfmdDL.show_batch[source]

TfmdDL.show_batch(b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs)

Show b (defaults to one_batch), a list of lists of pipeline outputs (i.e. output of a DataLoader)

show_doc(TfmdDL.to)

TfmdDL.to[source]

TfmdDL.to(device)

Put self and its transforms state on device

DataLoaders -

@docs
class DataLoaders(GetAttr):
    "Basic wrapper around several `DataLoader`s."
    _default='train'
    def __init__(self, 
        *loaders, # `DataLoader` objects to wrap
        path:str|Path='.', # Path to store export objects
        device=None # Device to put `DataLoaders`
    ):
        self.loaders,self.path = list(loaders),Path(path)
        if device is not None and (loaders!=() and hasattr(loaders[0],'to')): self.device = device

    def __getitem__(self, i): return self.loaders[i]
    def __len__(self): return len(self.loaders)
    def new_empty(self):
        loaders = [dl.new(dl.dataset.new_empty()) for dl in self.loaders]
        return type(self)(*loaders, path=self.path, device=self.device)

    def _set(i, self, v): self.loaders[i] = v
    train   ,valid    = add_props(lambda i,x: x[i], _set)
    train_ds,valid_ds = add_props(lambda i,x: x[i].dataset)

    @property
    def device(self): return self._device

    @device.setter
    def device(self, 
        d # Device to put `DataLoaders`
    ):
        for dl in self.loaders: dl.to(d)
        self._device = d

    def to(self, 
        device # Device to put `DataLoaders`
    ):
        self.device = device
        return self
            
    def _add_tfms(self, tfms, event, dl_idx):
        "Adds `tfms` to `event` on `dl`"
        if(isinstance(dl_idx,str)): dl_idx = 0 if(dl_idx=='train') else 1
        dl_tfms = getattr(self[dl_idx], event)
        apply(dl_tfms.add, tfms)
        
    def add_tfms(self,
        tfms, # List of `Transform`(s) or `Pipeline` to apply
        event, # When to run the `Transform`. Events mentioned in `TfmdDL`
        loaders=None # List of `DataLoader` objects to add the `tfms` to
    ):
        "Adds `tfms` to `events` on `loaders`"
        if(loaders is None): loaders=range(len(self.loaders))
        if not is_listy(loaders): loaders = listify(loaders)
        for loader in loaders:
            self._add_tfms(tfms,event,loader)      

    def cuda(self): return self.to(device=default_device())
    def cpu(self):  return self.to(device=torch.device('cpu'))

    @classmethod
    def from_dsets(cls, 
        *ds, # `Datasets` object(s)
        path:str|Path='.', # Path to put `DataLoaders`
        bs:int=64, # Batch size
        device=None, # Device to put `DataLoaders`
        dl_type=TfmdDL, # Type of `DataLoader`
        **kwargs
    ):
        default = (True,) + (False,) * (len(ds)-1)
        defaults = {'shuffle': default, 'drop_last': default}
        tfms = {k:tuple(Pipeline(kwargs[k]) for i in range_of(ds)) for k in _batch_tfms if k in kwargs}
        kwargs = merge(defaults, {k: tuplify(v, match=ds) for k,v in kwargs.items() if k not in _batch_tfms}, tfms)
        kwargs = [{k: v[i] for k,v in kwargs.items()} for i in range_of(ds)]
        return cls(*[dl_type(d, bs=bs, **k) for d,k in zip(ds, kwargs)], path=path, device=device)

    @classmethod
    def from_dblock(cls, 
        dblock, # `DataBlock` object
        source, # Source of data. Can be a `Path` to files
        path:str|Path='.', # Path to put `DataLoaders`
        bs:int=64, # Batch size
        val_bs:int=None, # Batch size for validation `DataLoader`
        shuffle:bool=True, # Whether to shuffle data
        device=None, # Device to put `DataLoaders`
        **kwargs
    ):
        return dblock.dataloaders(source, path=path, bs=bs, val_bs=val_bs, shuffle=shuffle, device=device, **kwargs)

    _docs=dict(__getitem__="Retrieve `DataLoader` at `i` (`0` is training, `1` is validation)",
               train="Training `DataLoader`",
               valid="Validation `DataLoader`",
               train_ds="Training `Dataset`",
               valid_ds="Validation `Dataset`",
               to="Use `device`",
               add_tfms="Add `tfms` to `loaders` for `event`",
               cuda="Use accelerator if available",
               cpu="Use the cpu",
               new_empty="Create a new empty version of `self` with the same transforms",
               from_dblock="Create a dataloaders from a given `dblock`")
dls = DataLoaders(tdl,tdl)
x = dls.train.one_batch()
x2 = first(tdl)
test_eq(x,x2)
x2 = dls.one_batch()
test_eq(x,x2)
# Test assignment works
dls.train = dls.train.new(bs=4)

Multiple transforms can be added to multiple dataloaders using `DataLoaders.add_tfms`. You can specify the dataloaders by a list of names, `dls.add_tfms(...,'valid',...)`, or by index, `dls.add_tfms(...,1,...)`. By default, transforms are added to all dataloaders. `event` is a required argument that determines when the transforms run; for more information on events, see `TfmdDL`. `tfms` is a list of `Transform`s and is also a required argument.

class _TestTfm(Transform):
    def encodes(self, o):  return torch.ones_like(o)
    def decodes(self, o):  return o
tdl1,tdl2 = TfmdDL(start, bs=4),TfmdDL(start, bs=4)
dls2 = DataLoaders(tdl1,tdl2)
dls2.add_tfms([_TestTfm()],'after_batch',['valid'])
dls2.add_tfms([_TestTfm()],'after_batch',[1])
dls2.train.after_batch,dls2.valid.after_batch,
(Pipeline: , Pipeline: _TestTfm -> _TestTfm)
test_eq(len(dls2.train.after_batch.fs),0)
test_eq(len(dls2.valid.after_batch.fs),2)
test_eq(next(iter(dls2.valid)),tensor([1,1,1,1]))
class _T(Transform):  
    def encodes(self, o):  return -o
class _T2(Transform): 
    def encodes(self, o):  return o/2

# Test transforms are applied to both training and validation dataloaders
dls_from_ds = DataLoaders.from_dsets([1,], [5,], bs=1, after_item=_T, after_batch=_T2)
b = first(dls_from_ds.train)
test_eq(b, tensor([-.5]))
b = first(dls_from_ds.valid)
test_eq(b, tensor([-2.5]))

Methods

show_doc(DataLoaders.__getitem__)

DataLoaders.__getitem__[source]

DataLoaders.__getitem__(i)

Retrieve DataLoader at i (0 is training, 1 is validation)

x2
tensor([ 0, -1, -2, -3])
x2 = dls[0].one_batch()
test_eq(x,x2)
show_doc(DataLoaders.train, name="DataLoaders.train")

DataLoaders.train[source]

Training DataLoader

show_doc(DataLoaders.valid, name="DataLoaders.valid")

DataLoaders.valid[source]

Validation DataLoader

show_doc(DataLoaders.train_ds, name="DataLoaders.train_ds")

DataLoaders.train_ds[source]

Training Dataset

show_doc(DataLoaders.valid_ds, name="DataLoaders.valid_ds")

DataLoaders.valid_ds[source]

Validation Dataset

TfmdLists -

class FilteredBase:
    "Base class for lists with subsets"
    _dl_type,_dbunch_type = TfmdDL,DataLoaders
    def __init__(self, *args, dl_type=None, **kwargs):
        if dl_type is not None: self._dl_type = dl_type
        self.dataloaders = delegates(self._dl_type.__init__)(self.dataloaders)
        super().__init__(*args, **kwargs)

    @property
    def n_subsets(self): return len(self.splits)
    def _new(self, items, **kwargs): return super()._new(items, splits=self.splits, **kwargs)
    def subset(self): raise NotImplementedError

    def dataloaders(self, 
        bs:int=64, # Batch size
        shuffle_train:bool=None, # (Deprecated, use `shuffle`) Shuffle training `DataLoader`
        shuffle:bool=True, # Shuffle training `DataLoader`
        val_shuffle:bool=False, # Shuffle validation `DataLoader`
        n:int=None, # Size of `Datasets` used to create `DataLoader`
        path:str|Path='.', # Path to put `DataLoaders`
        dl_type:TfmdDL=None, # Type of `DataLoader`
        dl_kwargs:list=None, # List of kwargs to pass to individual `DataLoader`s
        device:torch.device=None, # Device to put `DataLoaders`
        drop_last:bool=None, # Drop last incomplete batch, defaults to `shuffle`
        val_bs:int=None, # Validation batch size, defaults to `bs`
        **kwargs
    ) -> DataLoaders:
        if shuffle_train is not None: 
            shuffle=shuffle_train
            warnings.warn('`shuffle_train` is deprecated. Use `shuffle` instead.',DeprecationWarning)
        if device is None: device=default_device()
        if dl_kwargs is None: dl_kwargs = [{}] * self.n_subsets
        if dl_type is None: dl_type = self._dl_type
        if drop_last is None: drop_last = shuffle
        val_kwargs={k[4:]:v for k,v in kwargs.items() if k.startswith('val_')}
        def_kwargs = {'bs':bs,'shuffle':shuffle,'drop_last':drop_last,'n':n,'device':device}
        dl = dl_type(self.subset(0), **merge(kwargs,def_kwargs, dl_kwargs[0]))
        def_kwargs = {'bs':bs if val_bs is None else val_bs,'shuffle':val_shuffle,'n':None,'drop_last':False}
        dls = [dl] + [dl.new(self.subset(i), **merge(kwargs,def_kwargs,val_kwargs,dl_kwargs[i]))
                      for i in range(1, self.n_subsets)]
        return self._dbunch_type(*dls, path=path, device=device)    

FilteredBase.train,FilteredBase.valid = add_props(lambda i,x: x.subset(i))
show_doc(FilteredBase().dataloaders)

FilteredBase.dataloaders[source]

FilteredBase.dataloaders(bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, path='.', dl_type=None, dl_kwargs=None, device=None, drop_last=None, val_bs=None, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

class TfmdLists(FilteredBase, L, GetAttr):
    "A `Pipeline` of `tfms` applied to a collection of `items`"
    _default='tfms'
    def __init__(self, 
        items:list, # Items to apply `Transform`s to
        tfms:MutableSequence|Pipeline, # `Transform`(s) or `Pipeline` to apply
        use_list:bool=None, # Use `list` in `L`
        do_setup:bool=True, # Call `setup()` for `Transform`
        split_idx:int=None, # Apply `Transform`(s) to the training or validation set: `0` for the training set, `1` for the validation set
        train_setup:bool=True, # Apply `Transform`(s) only on training `DataLoader`
        splits:list=None, # Indices for the training and validation sets
        types=None, # Types of data in `items`
        verbose:bool=False, # Print verbose output
        dl_type:TfmdDL=None # Type of `DataLoader`
    ):
        super().__init__(items, use_list=use_list)
        if dl_type is not None: self._dl_type = dl_type
        self.splits = L([slice(None),[]] if splits is None else splits).map(mask2idxs)
        if isinstance(tfms,TfmdLists): tfms = tfms.tfms
        if isinstance(tfms,Pipeline): do_setup=False
        self.tfms = Pipeline(tfms, split_idx=split_idx)
        store_attr('types,split_idx')
        if do_setup:
            pv(f"Setting up {self.tfms}", verbose)
            self.setup(train_setup=train_setup)

    def _new(self, items, split_idx=None, **kwargs):
        split_idx = ifnone(split_idx,self.split_idx)
        try: return super()._new(items, tfms=self.tfms, do_setup=False, types=self.types, split_idx=split_idx, **kwargs)
        except IndexError as e:
            e.args = [f"Tried to grab subset {i} in the Dataset, but it contained no items.\n\t{e.args[0]}"]
            raise
    def subset(self, i): return self._new(self._get(self.splits[i]), split_idx=i)
    def _after_item(self, o): return self.tfms(o)
    def __repr__(self): return f"{self.__class__.__name__}: {self.items}\ntfms - {self.tfms.fs}"
    def __iter__(self): return (self[i] for i in range(len(self)))
    def show(self, o, **kwargs): return self.tfms.show(o, **kwargs)
    def decode(self, o, **kwargs): return self.tfms.decode(o, **kwargs)
    def __call__(self, o, **kwargs): return self.tfms.__call__(o, **kwargs)
    def overlapping_splits(self): return L(Counter(self.splits.concat()).values()).filter(gt(1))
    def new_empty(self): return self._new([])

    def setup(self, 
        train_setup:bool=True # Apply `Transform`(s) only on training `DataLoader`
    ):
        self.tfms.setup(self, train_setup)
        if len(self) != 0:
            x = super().__getitem__(0) if self.splits is None else super().__getitem__(self.splits[0])[0]
            self.types = []
            for f in self.tfms.fs:
                self.types.append(getattr(f, 'input_types', type(x)))
                x = f(x)
            self.types.append(type(x))
        types = L(t if is_listy(t) else [t] for t in self.types).concat().unique()
        self.pretty_types = '\n'.join([f'  - {t}' for t in types])

    def infer_idx(self, x):
        # TODO: check if we really need this, or if it can be simplified
        idx = 0
        for t in self.types:
            if isinstance(x, t): break
            idx += 1
        types = L(t if is_listy(t) else [t] for t in self.types).concat().unique()
        pretty_types = '\n'.join([f'  - {t}' for t in types])
        assert idx < len(self.types), f"Expected an input of type in \n{pretty_types}\n but got {type(x)}"
        return idx

    def infer(self, x):
        return compose_tfms(x, tfms=self.tfms.fs[self.infer_idx(x):], split_idx=self.split_idx)

    def __getitem__(self, idx):
        res = super().__getitem__(idx)
        if self._after_item is None: return res
        return self._after_item(res) if is_indexer(idx) else res.map(self._after_item)
add_docs(TfmdLists,
         setup="Transform setup with self",
         decode="From `Pipeline`",
         show="From `Pipeline`",
         overlapping_splits="All splits that are in more than one split",
         subset="New `TfmdLists` with same tfms that only includes items in `i`th split",
         infer_idx="Finds the index where `self.tfms` can be applied to `x`, depending on the type of `x`",
         infer="Apply `self.tfms` to `x` starting at the right tfm depending on the type of `x`",
         new_empty="A new version of `self` but with no items")

def decode_at(o, idx):
    "Decoded item at `idx`"
    return o.decode(o[idx])

def show_at(o, idx, **kwargs):
    "Show item at `idx`"
    return o.show(o[idx], **kwargs)

A `TfmdLists` combines a collection of objects with a `Pipeline`. `tfms` can either be a `Pipeline` or a list of transforms, in which case it will wrap them in a `Pipeline`. `use_list` is passed along to `L` with the `items`, and `split_idx` is passed to each transform of the `Pipeline`. `do_setup` indicates if the `Pipeline.setup` method should be called during initialization.

class _IntFloatTfm(Transform):
    def encodes(self, o):  return TitledInt(o)
    def decodes(self, o):  return TitledFloat(o)
int2f_tfm=_IntFloatTfm()

def _neg(o): return -o
neg_tfm = Transform(_neg, _neg)
items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]
tl = TfmdLists(items, tfms=tfms)
test_eq_type(tl[0], TitledInt(-1))
test_eq_type(tl[1], TitledInt(-2))
test_eq_type(tl.decode(tl[2]), TitledFloat(3.))
test_stdout(lambda: show_at(tl, 2), '-3')
test_eq(tl.types, [float, float, TitledInt])
tl
TfmdLists: [1.0, 2.0, 3.0]
tfms - [_neg:
encodes: (object,object) -> _neg
decodes: (object,object) -> _neg, _IntFloatTfm:
encodes: (object,object) -> encodes
decodes: (object,object) -> decodes
]
# Add splits to a TfmdLists
splits = [[0,2],[1]]
tl = TfmdLists(items, tfms=tfms, splits=splits)
test_eq(tl.n_subsets, 2)
test_eq(tl.train, tl.subset(0))
test_eq(tl.valid, tl.subset(1))
test_eq(tl.train.items, items[splits[0]])
test_eq(tl.valid.items, items[splits[1]])
test_eq(tl.train.tfms.split_idx, 0)
test_eq(tl.valid.tfms.split_idx, 1)
test_eq(tl.train.new_empty().split_idx, 0)
test_eq(tl.valid.new_empty().split_idx, 1)
test_eq_type(tl.splits, L(splits))
assert not tl.overlapping_splits()
df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))
tl = TfmdLists(df, lambda o: o.a+1, splits=[[0],[1,2]])
test_eq(tl[1,2], [3,4])
tr = tl.subset(0)
test_eq(tr[:], [2])
val = tl.subset(1)
test_eq(val[:], [3,4])
class _B(Transform):
    def __init__(self): self.m = 0
    def encodes(self, o): return o+self.m
    def decodes(self, o): return o-self.m
    def setups(self, items): 
        print(items)
        self.m = tensor(items).float().mean().item()

# Test setup, which updates `self.m`
tl = TfmdLists(items, _B())
test_eq(tl.m, 2)
TfmdLists: [1.0, 2.0, 3.0]
tfms - []

Here's how we can use `TfmdLists.setup` to implement a simple category list, getting labels from a mock file list:

class _Cat(Transform):
    order = 1
    def encodes(self, o):    return int(self.o2i[o])
    def decodes(self, o):    return TitledStr(self.vocab[o])
    def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)
tcat = _Cat()

def _lbl(o): return TitledStr(o.split('_')[0])

# Check that transforms are sorted by `order` and that `_lbl` is called first
fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
tl = TfmdLists(fns, [tcat,_lbl])
exp_voc = ['cat','dog']
test_eq(tcat.vocab, exp_voc)
test_eq(tl.tfms.vocab, exp_voc)
test_eq(tl.vocab, exp_voc)
test_eq(tl, (1,0,0,0,1))
test_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))
# Only the training set is taken into account for setup
tl = TfmdLists(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])
test_eq(tcat.vocab, ['dog'])
tfm = NegTfm(split_idx=1)
tds = TfmdLists(start, A())
tdl = TfmdDL(tds, after_batch=tfm, bs=4)
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds.split_idx = 1
x = tdl.one_batch()
test_eq(x, -torch.arange(4))
tds.split_idx = 0
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds = TfmdLists(start, A())
tdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
test_eq(len(tdl), (len(tds)-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
show_doc(TfmdLists.subset)

TfmdLists.subset[source]

TfmdLists.subset(i)

New TfmdLists with same tfms that only includes items in ith split

show_doc(TfmdLists.infer_idx)

TfmdLists.infer_idx[source]

TfmdLists.infer_idx(x)

Finds the index where self.tfms can be applied to x, depending on the type of x

show_doc(TfmdLists.infer)

TfmdLists.infer[source]

TfmdLists.infer(x)

Apply self.tfms to x starting at the right tfm depending on the type of x

def mult(x): return x*2
mult.order = 2

fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
tl = TfmdLists(fns, [_lbl,_Cat(),mult])

test_eq(tl.infer_idx('dog_45.jpg'), 0)
test_eq(tl.infer('dog_45.jpg'), 2)

test_eq(tl.infer_idx(4), 2)
test_eq(tl.infer(4), 8)

test_fail(lambda: tl.infer_idx(2.0))
test_fail(lambda: tl.infer(2.0))
# Test the `input_types` feature of Transform
cat = _Cat()
cat.input_types = (str, float)
tl = TfmdLists(fns, [_lbl,cat,mult])
test_eq(tl.infer_idx(2.0), 1)

# Test type annotations on functions are respected
def mult(x:int|float): return x*2
mult.order = 2
tl = TfmdLists(fns, [_lbl,_Cat(),mult])
test_eq(tl.infer_idx(2.0), 2)

Datasets -

@docs
@delegates(TfmdLists)
class Datasets(FilteredBase):
    "A dataset that creates a tuple from each `tfms`"
    def __init__(self, 
        items:list=None, # List of items to create `Datasets` from
        tfms:MutableSequence|Pipeline=None, # List of `Transform`(s) or `Pipeline`(s) to apply
        tls:TfmdLists=None, # If None, `self.tls` is generated from `items` and `tfms`
        n_inp:int=None, # Number of elements in `Datasets` tuple that should be considered part of the input
        dl_type=None, # Default type of `DataLoader` used when `FilteredBase.dataloaders` is called
        **kwargs
    ):
        super().__init__(dl_type=dl_type)
        self.tls = L(tls if tls else [TfmdLists(items, t, **kwargs) for t in L(ifnone(tfms,[None]))])
        self.n_inp = ifnone(n_inp, max(1, len(self.tls)-1))

    def __getitem__(self, it):
        res = tuple([tl[it] for tl in self.tls])
        return res if is_indexer(it) else list(zip(*res))

    def __getattr__(self,k): return gather_attrs(self, k, 'tls')
    def __dir__(self): return super().__dir__() + gather_attr_names(self, 'tls')
    def __len__(self): return len(self.tls[0])
    def __iter__(self): return (self[i] for i in range(len(self)))
    def __repr__(self): return coll_repr(self)
    def decode(self, o, full=True): return tuple(tl.decode(o_, full=full) for o_,tl in zip(o,tuplify(self.tls, match=o)))
    def subset(self, i): return type(self)(tls=L(tl.subset(i) for tl in self.tls), n_inp=self.n_inp)
    def _new(self, items, *args, **kwargs): return super()._new(items, tfms=self.tfms, do_setup=False, **kwargs)
    def overlapping_splits(self): return self.tls[0].overlapping_splits()
    def new_empty(self): return type(self)(tls=[tl.new_empty() for tl in self.tls], n_inp=self.n_inp)
    @property
    def splits(self): return self.tls[0].splits
    @property
    def split_idx(self): return self.tls[0].tfms.split_idx
    @property
    def items(self): return self.tls[0].items
    @items.setter
    def items(self, v):
        for tl in self.tls: tl.items = v

    def show(self, o, ctx=None, **kwargs):
        for o_,tl in zip(o,self.tls): ctx = tl.show(o_, ctx=ctx, **kwargs)
        return ctx

    @contextmanager
    def set_split_idx(self, i):
        old_split_idx = self.split_idx
        for tl in self.tls: tl.tfms.split_idx = i
        try: yield self
        finally:
            for tl in self.tls: tl.tfms.split_idx = old_split_idx

    _docs=dict(
        decode="Compose `decode` of all `tuple_tfms` then all `tfms` on `i`",
        show="Show item `o` in `ctx`",
        dataloaders="Get a `DataLoaders`",
        overlapping_splits="All splits that are in more than one split",
        subset="New `Datasets` that only includes subset `i`",
        new_empty="Create a new empty version of the `self`, keeping only the transforms",
        set_split_idx="Contextmanager to use the same `Datasets` with another `split_idx`"
    )

A `Datasets` creates a tuple from `items` (typically the input and the target) by applying to them each `Transform` (or `Pipeline`) in `tfms`. Note that if `tfms` contains only one list of `tfms`, the items given by `Datasets` will be tuples of one element.

`n_inp` is the number of elements in the tuples that should be considered part of the input; it defaults to 1 if `tfms` consists of one list of transforms, and to `len(tfms)-1` otherwise. In most cases, the number of elements in the tuples output by `Datasets` will be 2 (input, target), but it can happen to be 3 (as in Siamese networks or tabular data), in which case we need to be able to determine where the inputs end and the targets begin.
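
For instance, here is a sketch of a Siamese-style `Datasets` built from three lists of `tfms` (two inputs and one target; `itemgetter` ships with fastcore's imports, and the pair items are lists so that the callables receive them whole):

pairs = [[1,2],[3,3]]
dsets3 = Datasets(pairs, [[itemgetter(0)], [itemgetter(1)], [lambda o: o[0]==o[1]]])
test_eq(dsets3.n_inp, 2)            # defaults to len(tfms)-1
test_eq(dsets3[1], (3, 3, True))    # (input, input, target)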

items = [1,2,3,4]
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [add(1)]])
t = dsets[0]
test_eq(t, (-1,2))
test_eq(dsets[0,1,2], [(-1,2),(-2,3),(-3,4)])
test_eq(dsets.n_inp, 1)
dsets.decode(t)
(1.0, 2)
class Norm(Transform):
    def encodes(self, o): return (o-self.m)/self.s
    def decodes(self, o): return (o*self.s)+self.m
    def setups(self, items):
        its = tensor(items).float()
        self.m,self.s = its.mean(),its.std()
items = [1,2,3,4]
nrm = Norm()
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])

x,y = zip(*dsets)
test_close(tensor(y).mean(), 0)
test_close(tensor(y).std(), 1)
test_eq(x, (-1,-2,-3,-4,))
test_eq(nrm.m, -2.5)
test_stdout(lambda:show_at(dsets, 1), '-2')

test_eq(dsets.m, nrm.m)
test_eq(dsets.norm.m, nrm.m)
test_eq(dsets.train.norm.m, nrm.m)
# Check that filters are properly applied
class B(Transform):
    def encodes(self, x)->None:  return int(x+1)
    def decodes(self, x):        return TitledInt(x-1)
add1 = B(split_idx=1)

dsets = Datasets(items, [neg_tfm, [neg_tfm,int2f_tfm,add1]], splits=[[3],[0,1,2]])
test_eq(dsets[1], [-2,-2])
test_eq(dsets.valid[1], [-2,-1])
test_eq(dsets.valid[[1,1]], [[-2,-1], [-2,-1]])
test_eq(dsets.train[0], [-4,-4])
test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']
tcat = _Cat()
dsets = Datasets(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])
test_eq(tcat.vocab, ['cat','dog'])
test_eq(dsets.train, [(1,),(0,),(0,)])
test_eq(dsets.valid[0], (0,))
test_stdout(lambda: show_at(dsets.train, 0), "dog")
inp = [0,1,2,3,4]
dsets = Datasets(inp, tfms=[None])

test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)
test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsets[mask], [(0,),(3,)])   # Retrieve two items by mask
inp = pd.DataFrame(dict(a=[5,1,2,3,4]))
dsets = Datasets(inp, tfms=attrgetter('a')).subset(0)
test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)
test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsets[mask], [(5,),(3,)])   # Retrieve two items by mask
# Test n_inp
inp = [0,1,2,3,4]
dsets = Datasets(inp, tfms=[None])
test_eq(dsets.n_inp, 1)
dsets = Datasets(inp, tfms=[[None],[None],[None]])
test_eq(dsets.n_inp, 2)
dsets = Datasets(inp, tfms=[[None],[None],[None]], n_inp=1)
test_eq(dsets.n_inp, 1)
# Splits can be indices
dsets = Datasets(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])

test_eq(dsets.subset(0), [(0,),(2,)])
test_eq(dsets.train, [(0,),(2,)])       # Subset 0 is aliased to `train`
test_eq(dsets.subset(1), [(1,),(3,),(4,)])
test_eq(dsets.valid, [(1,),(3,),(4,)])     # Subset 1 is aliased to `valid`
test_eq(*dsets.valid[2], 4)
#assert '[(1,),(3,),(4,)]' in str(dsets) and '[(0,),(2,)]' in str(dsets)
dsets
(#5) [(0,),(1,),(2,),(3,),(4,)]
# Splits can be boolean masks (they don't have to cover all items, but must be disjoint)
splits = [[False,True,True,False,True], [True,False,False,False,False]]
dsets = Datasets(range(5), tfms=[None], splits=splits)

test_eq(dsets.train, [(1,),(2,),(4,)])
test_eq(dsets.valid, [(0,)])
# Apply transforms to all items
tfm = [[lambda x: x*2,lambda x: x+1]]
splits = [[1,2],[0,3,4]]
dsets = Datasets(range(5), tfm, splits=splits)
test_eq(dsets.train,[(3,),(5,)])
test_eq(dsets.valid,[(1,),(7,),(9,)])
test_eq(dsets.train[False,True], [(5,)])
# Only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2
    def decodes(self, x): return TitledStr(x//2)
dsets = Datasets(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])
test_eq(dsets.train,[(1,),(2,)])
test_eq(dsets.valid,[(0,),(6,),(8,)])
test_eq(dsets.train[False,True], [(2,)])
dsets
(#5) [(0,),(1,),(2,),(3,),(4,)]
# A context manager to change `split_idx` and apply the validation transform on the training set
ds = dsets.train
with ds.set_split_idx(1):
    test_eq(ds,[(2,),(4,)])
test_eq(dsets.train,[(1,),(2,)])
# Test pickling a Datasets
dsrc1 = pickle.loads(pickle.dumps(dsets))
test_eq(dsets.train, dsrc1.train)
test_eq(dsets.valid, dsrc1.valid)
dsets = Datasets(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])
test_eq(dsets.train,[(1,1),(2,2)])
test_eq(dsets.valid,[(0,0),(6,3),(8,4)])
start = torch.arange(0,50)
tds = Datasets(start, [A()])
tdl = TfmdDL(tds, after_item=NegTfm(), bs=4)
b = tdl.one_batch()
test_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))
test_stdout(tdl.show_batch, "0\n1\n2\n3")
# Only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2

dsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])
dls = dsets.dataloaders(bs=4, after_batch=_Tfm(), shuffle=False, device=torch.device('cpu'))
test_eq(dls.train, [(tensor([1,2,5, 7]),)])
test_eq(dls.valid, [(tensor([0,6,8,12]),)])
test_eq(dls.n_inp, 1)

Methods

items = [1,2,3,4]
dsets = Datasets(items, [[neg_tfm,int2f_tfm]])
_dsrc = Datasets([1,2])
show_doc(_dsrc.dataloaders, name="Datasets.dataloaders")

Datasets.dataloaders[source]

Datasets.dataloaders(bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, path='.', dl_type=None, dl_kwargs=None, device=None, drop_last=None, val_bs=None, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

Get a DataLoaders

Used to create dataloaders. You can prepend `val_` to any argument, as in `val_shuffle`, to override that behaviour for the validation set. `dl_kwargs` gives finer-grained per-dataloader control if you need to work with multiple dataloaders.
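
For example, a sketch assuming a `dsets` with train/valid splits (unlike the one-list example just above):

dls = dsets.dataloaders(
    bs=64, val_bs=128,                # validation `DataLoader` gets its own batch size
    shuffle=True, val_shuffle=False,  # the `val_` prefix overrides behaviour for validation
    dl_kwargs=[{}, {'num_workers':0}] # per-`DataLoader` keyword arguments
)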

show_doc(Datasets.decode)

Datasets.decode[source]

Datasets.decode(o, full=True)

Compose decode of all tuple_tfms then all tfms on i

test_eq(*dsets[0], -1)
test_eq(*dsets.decode((-1,)), 1)
show_doc(Datasets.show)

Datasets.show[source]

Datasets.show(o, ctx=None, **kwargs)

Show item o in ctx

test_stdout(lambda:dsets.show(dsets[1]), '-2')
show_doc(Datasets.new_empty)

Datasets.new_empty[source]

Datasets.new_empty()

Create a new empty version of the self, keeping only the transforms

items = [1,2,3,4]
nrm = Norm()
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm]])
empty = dsets.new_empty()
test_eq(empty.items, [])
# Test it works on a dataframe too
df = pd.DataFrame({'a':[1,2,3,4,5], 'b':[6,7,8,9,10]})
dsets = Datasets(df, [[attrgetter('a')], [attrgetter('b')]])
empty = dsets.new_empty()

Add test set for inference

# Only transform subset 0
class _Tfm1(Transform):
    split_idx=0
    def encodes(self, x): return x*3

dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
test_eq(dsets.train, [(3,),(6,),(15,),(21,)])
test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])
def test_set(
    dsets:Datasets|TfmdLists, # Map- or iterable-style dataset from which to load the data
    test_items, # Items in test dataset
    rm_tfms=None, # Start index of `Transform`(s) from the validation set in `dsets` to apply
    with_labels:bool=False # Whether the test items contain labels
):
    "Create a test set from `test_items` using validation transforms of `dsets`"
    if isinstance(dsets, Datasets):
        tls = dsets.tls if with_labels else dsets.tls[:dsets.n_inp]
        test_tls = [tl._new(test_items, split_idx=1) for tl in tls]
        if rm_tfms is None: rm_tfms = [tl.infer_idx(get_first(test_items)) for tl in test_tls]
        else:               rm_tfms = tuplify(rm_tfms, match=test_tls)
        for i,j in enumerate(rm_tfms): test_tls[i].tfms.fs = test_tls[i].tfms.fs[j:]
        return Datasets(tls=test_tls)
    elif isinstance(dsets, TfmdLists):
        test_tl = dsets._new(test_items, split_idx=1)
        if rm_tfms is None: rm_tfms = dsets.infer_idx(get_first(test_items))
        test_tl.tfms.fs = test_tl.tfms.fs[rm_tfms:]
        return test_tl
    else: raise Exception(f"This method requires using the fastai library to assemble your data. Expected a `Datasets` or a `TfmdLists` but got {dsets.__class__.__name__}")

# Transforms of the validation set are applied
tst = test_set(dsets, [1,2,3])
test_eq(tst, [(2,),(4,),(6,)])
# Test with different types
tfm = _Tfm1()
tfm.split_idx,tfm.order = None,2
dsets = Datasets(['dog', 'cat', 'cat', 'dog'], [[_Cat(),tfm]])

# With strings
test_eq(test_set(dsets, ['dog', 'cat', 'cat']), [(3,), (0,), (0,)])
# With ints
test_eq(test_set(dsets, [1,2]), [(3,), (6,)])
# Test with various input lengths
dsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsets, [1,2,3])
test_eq(tst, [(2,2),(4,4),(6,6)])

dsets = Datasets(range(8), [[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()],[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=1)
tst = test_set(dsets, [1,2,3])
test_eq(tst, [(2,),(4,),(6,)])
# Test with rm_tfms
dsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsets, [1,2,3])
test_eq(tst, [(4,),(8,),(12,)])

dsets = Datasets(range(8), [[_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]])
tst = test_set(dsets, [1,2,3], rm_tfms=1)
test_eq(tst, [(2,),(4,),(6,)])

dsets = Datasets(range(8), [[_Tfm(),_Tfm()], [_Tfm(),_Tfm()]], splits=[[1,2,5,7],[0,3,4,6]], n_inp=2)
tst = test_set(dsets, [1,2,3], rm_tfms=(1,0))
test_eq(tst, [(2,4),(4,8),(6,12)])
@patch
@delegates(TfmdDL.__init__)
def test_dl(self:DataLoaders, 
    test_items, # Items in test dataset
    rm_type_tfms=None, # Start index of `Transform`(s) from the validation set in `dsets` to apply
    with_labels:bool=False, # Whether the test items contain labels
    **kwargs
):
    "Create a test dataloader from `test_items` using validation transforms of `dls`"
    test_ds = test_set(self.valid_ds, test_items, rm_tfms=rm_type_tfms, with_labels=with_labels
                      ) if isinstance(self.valid_ds, (Datasets, TfmdLists)) else test_items
    return self.valid.new(test_ds, **kwargs)
dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
dls = dsets.dataloaders(bs=4, device=torch.device('cpu'))
tst_dl = dls.test_dl([2,3,4,5])
test_eq(tst_dl._n_inp, 1)
test_eq(list(tst_dl), [(tensor([ 4,  6,  8, 10]),)])
# Test you can change transforms
tst_dl = dls.test_dl([2,3,4,5], after_item=add1)
test_eq(list(tst_dl), [(tensor([ 5,  7,  9, 11]),)])

Export -

from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.