数据回调

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on Colab

用于处理 `Learner` 数据的回调函数

from __future__ import annotations
from fastai.basics import *
from nbdev.showdoc import *
from fastai.test_utils import *
class CollectDataCallback(Callback):
    "Record every batch, together with its `pred` and `loss`, in `self.data`. Intended mainly for tests"

    def before_fit(self):
        "Reset `self.data` to an empty `L` when fitting starts."
        self.data = L()

    def after_batch(self):
        "Append the detached `(xb, yb, pred, loss)` tuple of the batch that just finished."
        batch_record = (self.xb, self.yb, self.pred, self.loss)
        self.data.append(self.learn.to_detach(batch_record))
@delegates()
class WeightedDL(TfmdDL):
    "`TfmdDL` whose shuffled indices are sampled according to `wgts` (meant for the training set only)"
    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):
        # No weights given: fall back to a weight of 1 for every item of `dataset`.
        if wgts is None: raw_wgts = [1.]*len(dataset)
        else: raw_wgts = wgts
        wgts = array(raw_wgts)
        # Normalise so the weights sum to 1; they are passed as `p` to `np.random.choice` below.
        self.wgts = wgts/wgts.sum()
        super().__init__(dataset=dataset, bs=bs, **kwargs)

    def get_idxs(self):
        "Return the item indices for one pass: weighted random draws when shuffling, the parent's order otherwise."
        if self.n==0: return []
        if self.shuffle:
            # `self.n` draws over `range(self.n)`, each index picked with probability `self.wgts[i]`.
            return list(np.random.choice(self.n, self.n, p=self.wgts))
        return super().get_idxs()
@patch
@delegates(Datasets.dataloaders)
def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):
    "Build dataloaders of type `WeightedDL`, passing `wgts` to the first (training) subset only"
    # One kwargs dict per subset: the first carries the weights, the remaining ones are empty.
    remaining_kwargs = [{} for _ in range(self.n_subsets-1)]
    per_subset_kwargs = ({'wgts':wgts}, *remaining_kwargs)
    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=per_subset_kwargs, **kwargs)
# Smoke test for `Datasets.weighted_dataloaders` on a tiny random binary-label dataset.
lbls = np.random.randint(0, 2, size=(10)) # Dataset of size 10 (train=8, valid=2)
# Indices 8 and 9 form the validation split; indices 0-7 are the training split.
is_valid = lambda i: i >= 8
dblock = DataBlock(blocks=[CategoryBlock], 
    getters=[lambda i: lbls[i]], splitter=FuncSplitter(is_valid))
dset = dblock.datasets(list(range(10)))
item_tfms = [ToTensor()] 
# One weight per training item; item i gets weight i, so item 0 has weight 0.
wgts = range(8) # len(wgts) == 8
dls = dset.weighted_dataloaders(bs=1, wgts=wgts, after_item=item_tfms)
dls.show_batch() # if len(wgts) != 8, this will fail
1
# Train for one epoch on items 0..n-1 with weight i for item i, collecting every batch.
n = 160
dsets = Datasets(torch.arange(n).float())
dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)
learn = synth_learner(data=dls, cbs=CollectDataCallback)
learn.fit(1)
# `itemgot(0,0)` takes the first element of `xb` from each collected `(xb, yb, pred, loss)` record.
t = concat(*learn.collect_data.data.itemgot(0,0))
# Histogram of the collected inputs; presumably skewed towards larger values since weights grow with the index.
plt.hist(t.numpy());
[0, nan, None, '00:00']

@patch
@delegates(Datasets.weighted_dataloaders)
def weighted_dataloaders(self:DataBlock, source, wgts, bs=64, verbose:bool=False, **kwargs):
    "Build datasets from `source`, then `WeightedDL` dataloaders using the entries of `wgts` at the first split's indices"
    built = self.datasets(source, verbose=verbose)
    # Objects without `__array__` (e.g. list, range) are converted so they can be indexed by the split below.
    wgts = wgts if hasattr(wgts, '__array__') else np.array(wgts)
    train_idxs = built.splits[0]
    return built.weighted_dataloaders(
        wgts[train_idxs],
        bs=bs,
        after_batch=self.batch_tfms,
        after_item=self.item_tfms,
        **kwargs,
    )
# Same check through the `DataBlock` entry point, reusing `dblock` and `wgts` defined in the earlier cell.
dls = dblock.weighted_dataloaders(list(range(10)), wgts, bs=1)
dls.show_batch()
0
@delegates()
class PartialDL(TfmdDL):
    "`TfmdDL` that, when `partial_n` is set, randomly selects only `partial_n` items at each epoch"
    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        # A falsy `partial_n` (None or 0) disables sub-sampling; otherwise it is capped at `self.n`.
        if partial_n: self.partial_n = min(partial_n, self.n)
        else: self.partial_n = None

    def get_idxs(self):
        "Return `partial_n` distinct random indices, or the parent's indices when sub-sampling is disabled."
        if self.partial_n is not None:
            return list(np.random.choice(self.n, self.partial_n, replace=False))
        return super().get_idxs()

    def __len__(self):
        "Number of batches: based on `partial_n` when it is set, the parent's length otherwise."
        if self.partial_n is None: return super().__len__()
        full_batches, leftover = divmod(self.partial_n, self.bs)
        # A trailing incomplete batch counts only if one exists and `drop_last` is off.
        has_extra_batch = not (self.drop_last or leftover==0)
        return full_batches + (1 if has_extra_batch else 0)
@patch
@delegates(Datasets.dataloaders)
def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):
    "Build dataloaders of type `PartialDL`, passing `partial_n` to the first (training) subset only"
    # One kwargs dict per subset: the first carries `partial_n`, the remaining ones are empty.
    remaining_kwargs = [{} for _ in range(self.n_subsets-1)]
    per_subset_kwargs = ({'partial_n':partial_n}, *remaining_kwargs)
    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=per_subset_kwargs, **kwargs)
# With partial_n=32 and bs=16 the training loader should yield exactly 2 batches of 16 items each.
dls = dsets.partial_dataloaders(partial_n=32, bs=16)
assert len(dls[0])==2
for batch in dls[0]:
    assert len(batch[0])==16

导出 -

# Run nbdev's export step; the output below lists each notebook it converted.
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 20b_tutorial.distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 70a_callback.tensorboard.ipynb.
Converted 70b_callback.neptune.ipynb.
Converted 70c_callback.captum.ipynb.
Converted 70d_callback.comet.ipynb.
Converted 74_huggingface.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted distributed_app_examples.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.