! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastai
数据回调
用于处理 `Learner` 数据的回调函数
from __future__ import annotations
from fastai.basics import *
from nbdev.showdoc import *
from fastai.test_utils import *
class CollectDataCallback(Callback):
    "Record every batch together with its `pred` and `loss` in `self.data` (intended for tests)"

    def before_fit(self):
        # Each call to fit starts from an empty collection
        self.data = L()

    def after_batch(self):
        # Bundle inputs, targets, predictions and loss, then store what `to_detach` returns
        batch_record = (self.xb, self.yb, self.pred, self.loss)
        self.data.append(self.learn.to_detach(batch_record))
@delegates()
class WeightedDL(TfmdDL):
    "Weighted dataloader where `wgts` is used for the training set only"

    def __init__(self, dataset=None, bs=None, wgts=None, **kwargs):
        # `wgts` holds one sampling weight per item; when omitted every item gets weight 1.
        wgts = array([1.]*len(dataset) if wgts is None else wgts)
        # Normalise so the weights sum to 1, as required by the `p` argument of `np.random.choice`
        self.wgts = wgts/wgts.sum()
        super().__init__(dataset=dataset, bs=bs, **kwargs)

    def get_idxs(self):
        # Nothing to sample from an empty dataloader
        if self.n==0: return []
        # Weighted sampling only applies when shuffling; otherwise defer to the parent ordering
        if not self.shuffle: return super().get_idxs()
        # Draw `self.n` indices (with replacement, the `np.random.choice` default) according to `self.wgts`
        return list(np.random.choice(self.n, self.n, p=self.wgts))
@patch
@delegates(Datasets.dataloaders)
def weighted_dataloaders(self:Datasets, wgts, bs=64, **kwargs):
    "Create a weighted dataloader `WeightedDL` with `wgts` for the training set"
    # One empty kwargs dict for each subset after the first, so only the first subset receives `wgts`
    xtra_kwargs = [{}] * (self.n_subsets-1)
    return self.dataloaders(bs=bs, dl_type=WeightedDL, dl_kwargs=({'wgts':wgts}, *xtra_kwargs), **kwargs)
# Dataset of size 10 (train=8, valid=2)
lbls = np.random.randint(0, 2, size=(10))
is_valid = lambda i: i >= 8
dblock = DataBlock(blocks=[CategoryBlock],
                   getters=[lambda i: lbls[i]], splitter=FuncSplitter(is_valid))
dset = dblock.datasets(list(range(10)))
item_tfms = [ToTensor()]
wgts = range(8)  # len(wgts) == 8
dls = dset.weighted_dataloaders(bs=1, wgts=wgts, after_item=item_tfms)
dls.show_batch()  # this will fail if len(wgts) != 8
1
n = 160
dsets = Datasets(torch.arange(n).float())
# Weights grow linearly with the item value
dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)
learn = synth_learner(data=dls, cbs=CollectDataCallback)
learn.fit(1)
# Concatenate the collected inputs of every batch and plot their distribution
t = concat(*learn.collect_data.data.itemgot(0,0))
plt.hist(t.numpy());
[0, nan, None, '00:00']
@patch
@delegates(Datasets.weighted_dataloaders)
def weighted_dataloaders(self:DataBlock, source, wgts, bs=64, verbose:bool=False, **kwargs):
    "Create a weighted dataloader `WeightedDL` with `wgts` for the dataset"
    dss = self.datasets(source, verbose=verbose)
    # Convert inputs that lack `__array__` (e.g. lists, ranges) so they can be indexed by the split indices
    if not hasattr(wgts, '__array__'): wgts = np.array(wgts)
    # Keep only the weights of the items in the first split
    trn_wgts = wgts[dss.splits[0]]
    return dss.weighted_dataloaders(trn_wgts, bs=bs, after_batch=self.batch_tfms, after_item=self.item_tfms, **kwargs)
# `wgts` is passed for the full source here; the DataBlock method selects the first split's weights itself
dls = dblock.weighted_dataloaders(list(range(10)), wgts, bs=1)
dls.show_batch()
0
@delegates()
class PartialDL(TfmdDL):
    "Dataloader that draws a random subset of `partial_n` items at each epoch"

    def __init__(self, dataset=None, bs=None, partial_n=None, **kwargs):
        super().__init__(dataset=dataset, bs=bs, **kwargs)
        # A falsy `partial_n` (None or 0) turns subsampling off; otherwise it is capped at `self.n`
        if partial_n:
            self.partial_n = min(partial_n, self.n)
        else:
            self.partial_n = None

    def get_idxs(self):
        if self.partial_n is None:
            return super().get_idxs()
        # Sample `partial_n` distinct indices out of `self.n`
        sampled = np.random.choice(self.n, self.partial_n, replace=False)
        return list(sampled)

    def __len__(self):
        if self.partial_n is None:
            return super().__len__()
        # Number of batches: full batches plus one trailing short batch unless it is dropped
        n_full, remainder = divmod(self.partial_n, self.bs)
        keep_short_batch = remainder != 0 and not self.drop_last
        return n_full + (1 if keep_short_batch else 0)
@patch
@delegates(Datasets.dataloaders)
def partial_dataloaders(self:FilteredBase, partial_n, bs=64, **kwargs):
    "Create a partial dataloader `PartialDL` for the training set"
    # One empty kwargs dict for each subset after the first, so only the first subset receives `partial_n`
    xtra_kwargs = [{}] * (self.n_subsets-1)
    return self.dataloaders(bs=bs, dl_type=PartialDL, dl_kwargs=({'partial_n':partial_n}, *xtra_kwargs), **kwargs)
dls = dsets.partial_dataloaders(partial_n=32, bs=16)
# 32 sampled items at batch size 16 give exactly two full batches
assert len(dls[0])==2
for batch in dls[0]:
    assert len(batch[0])==16
导出 -
# Run nbdev's export step; the "Converted ..." lines that follow are this cell's recorded output
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 20b_tutorial.distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 70a_callback.tensorboard.ipynb.
Converted 70b_callback.neptune.ipynb.
Converted 70c_callback.captum.ipynb.
Converted 70d_callback.comet.ipynb.
Converted 74_huggingface.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted distributed_app_examples.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.