Tabular data

! [ -e /content ] && pip install -Uqq fastai  # 在Colab上升级fastai
from __future__ import annotations
from fastai.torch_basics import *
from fastai.data.all import *
from fastai.tabular.core import *
from nbdev.showdoc import *

Helper functions to get data in a DataLoaders in the tabular application, and the higher-level class TabularDataLoaders

The main class to get your data ready for model training is TabularDataLoaders and its factory methods. Check out the tabular tutorial for examples of use.

TabularDataLoaders -

class TabularDataLoaders(DataLoaders):
    "Basic wrapper around several `DataLoader`s with factory methods for tabular data"
    @classmethod
    @delegates(Tabular.dataloaders, but=["dl_type", "dl_kwargs"])
    def from_df(cls, 
        df:pd.DataFrame,
        path:str|Path='.', # Location of `df`, defaults to current working directory
        procs:list=None, # List of `TabularProc`s
        cat_names:list=None, # Column names pertaining to categorical variables
        cont_names:list=None, # Column names pertaining to continuous variables
        y_names:list=None, # Names of the dependent variables
        y_block:TransformBlock=None, # `TransformBlock` to use for the target(s)
        valid_idx:list=None, # List of indices to use for the validation set, defaults to a random split
        **kwargs
    ):
        "Create `TabularDataLoaders` from `df` in `path` using `procs`"
        if cat_names is None: cat_names = []
        if cont_names is None: cont_names = list(set(df)-set(L(cat_names))-set(L(y_names)))
        splits = RandomSplitter()(df) if valid_idx is None else IndexSplitter(valid_idx)(df)
        to = TabularPandas(df, procs, cat_names, cont_names, y_names, splits=splits, y_block=y_block)
        return to.dataloaders(path=path, **kwargs)

    @classmethod
    def from_csv(cls, 
        csv:str|Path|io.BufferedReader, # A csv of training data
        skipinitialspace:bool=True, # Skip spaces after delimiter
        **kwargs
    ):
        "Create `TabularDataLoaders` from `csv` file in `path` using `procs`"
        return cls.from_df(pd.read_csv(csv, skipinitialspace=skipinitialspace), **kwargs)

    @delegates(TabDataLoader.__init__)
    def test_dl(self, 
        test_items, # Items to create new test `TabDataLoader`, formatted the same as the training data
        rm_type_tfms=None, # Number of `Transform`s to be removed from `procs`
        process:bool=True, # Apply validation `TabularProc`s to `test_items` immediately
        inplace:bool=False, # Keep separate copy of original `test_items` in memory if `False`
        **kwargs
    ):
        "Create test `TabDataLoader` from `test_items` using validation `procs`"
        to = self.train_ds.new(test_items, inplace=inplace)
        if process: to.process()
        return self.valid.new(to, **kwargs)

Tabular._dbunch_type = TabularDataLoaders
TabularDataLoaders.from_csv = delegates(to=TabularDataLoaders.from_df)(TabularDataLoaders.from_csv)

This class should not be used directly; one of the factory methods should be preferred instead. All of those factory methods accept as arguments (see the sketch after this list for a minimal example):

  • cat_names: the names of the categorical variables
  • cont_names: the names of the continuous variables
  • y_names: the names of the dependent variables
  • y_block: the TransformBlock to use for the targets
  • valid_idx: the indices to use for the validation set (defaults to a random split)
  • bs: the batch size
  • val_bs: the batch size for the validation DataLoader (defaults to bs)
  • shuffle_train: whether or not to shuffle the training DataLoader
  • n: overrides the number of elements in the dataset
  • device: the PyTorch device to use (defaults to default_device())
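
As a minimal sketch of how these arguments fit together, the call below builds DataLoaders for a continuous target by passing y_block explicitly. The DataFrame df_reg and its column names are hypothetical placeholders, not part of the adult example shown below.

# Hypothetical regression setup: `df_reg` has a categorical column 'shop',
# a continuous column 'size' and a continuous target 'price'.
procs = [Categorify, FillMissing, Normalize]
dls_reg = TabularDataLoaders.from_df(df_reg, path='.', procs=procs,
    cat_names=['shop'], cont_names=['size'], y_names='price',
    y_block=RegressionBlock(),         # treat the target as continuous
    valid_idx=list(range(800, 1000)),  # explicit validation split
    bs=64)
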
show_doc(TabularDataLoaders.from_df)

TabularDataLoaders.from_df[source]

TabularDataLoaders.from_df(df:DataFrame, path:str|Path='.', procs:list=None, cat_names:list=None, cont_names:list=None, y_names:list=None, y_block:TransformBlock=None, valid_idx:list=None, bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, device=None, drop_last=None, val_bs=None)

Create TabularDataLoaders from df in path using procs

Let's have a look at an example with the adult dataset:

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country salary
0 49 Private 101320 Assoc-acdm 12.0 Married-civ-spouse NaN Wife White Female 0 1902 40 United-States >=50k
1 44 Private 236746 Masters 14.0 Divorced Exec-managerial Not-in-family White Male 10520 0 45 United-States >=50k
2 38 Private 96185 HS-grad NaN Divorced NaN Unmarried Black Female 0 0 32 United-States <50k
3 38 Self-emp-inc 112847 Prof-school 15.0 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male 0 0 40 United-States >=50k
4 42 Self-emp-not-inc 82297 7th-8th NaN Married-civ-spouse Other-service Wife Black Female 0 0 50 United-States <50k
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, 
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
workclass education marital-status occupation relationship race education-num_na age fnlwgt education-num salary
0 Private HS-grad Married-civ-spouse Adm-clerical Husband White False 24.0 121312.998272 9.0 <50k
1 Private HS-grad Never-married Other-service Not-in-family White False 19.0 198320.000325 9.0 <50k
2 Private Bachelors Married-civ-spouse Sales Husband White False 66.0 169803.999308 13.0 >=50k
3 Private HS-grad Divorced Adm-clerical Unmarried White False 40.0 799280.980929 9.0 <50k
4 Local-gov 10th Never-married Other-service Own-child White False 18.0 55658.003629 6.0 <50k
5 Private HS-grad Never-married Handlers-cleaners Other-relative White False 30.0 375827.003847 9.0 <50k
6 Private Some-college Never-married Handlers-cleaners Own-child White False 20.0 173723.999335 10.0 <50k
7 ? Some-college Never-married ? Own-child White False 21.0 107800.997986 10.0 <50k
8 Private HS-grad Never-married Handlers-cleaners Own-child White False 19.0 263338.000072 9.0 <50k
9 Private Some-college Married-civ-spouse Tech-support Husband White False 35.0 194590.999986 10.0 <50k
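
To look at the raw tensors behind this decoded view, a batch can be pulled straight from the DataLoaders. The short check below only inspects shapes and dtypes, and assumes the dls object created above.

# Grab one raw batch from the training DataLoader; for tabular data this is
# typically the categorical codes, the continuous values and the targets.
b = dls.one_batch()
for t in b: print(type(t), t.shape, t.dtype)
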
show_doc(TabularDataLoaders.from_csv)

TabularDataLoaders.from_csv[source]

TabularDataLoaders.from_csv(csv:str|Path|BufferedReader, skipinitialspace:bool=True, path:str|Path='.', procs:list=None, cat_names:list=None, cont_names:list=None, y_names:list=None, y_block:TransformBlock=None, valid_idx:list=None, bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, device=None, drop_last=None, val_bs=None)

Create TabularDataLoaders from csv file in path using procs

cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names, 
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)
show_doc(TabularDataLoaders.test_dl)

TabularDataLoaders.test_dl[source]

TabularDataLoaders.test_dl(test_items, rm_type_tfms=None, process:bool=True, inplace:bool=False, bs=16, shuffle=False, after_batch=None, num_workers=0, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

Create test TabDataLoader from test_items using validation procs

External structured data files can contain unexpected spaces, e.g. after a comma. We can see that in the first row of adult.csv, which reads "49, Private,101320, ...". Often trimming is needed. Pandas has a convenient parameter skipinitialspace that is exposed by TabularDataLoaders.from_csv(). Otherwise category labels used for inference later, such as workclass:Private, will be wrongly categorized to 0 or "#na#" if the training label was read as " Private". Let's test this feature.

test_data = {
    'age': [49], 
    'workclass': ['Private'], 
    'fnlwgt': [101320],
    'education': ['Assoc-acdm'], 
    'education-num': [12.0],
    'marital-status': ['Married-civ-spouse'], 
    'occupation': [''],
    'relationship': ['Wife'],
    'race': ['White'],
}
input = pd.DataFrame(test_data)
tdl = dls.test_dl(input)

test_ne(0, tdl.dataset.iloc[0]['workclass'])
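
Once a test TabDataLoader has been built this way, it is typically handed to a trained learner for inference. The sketch below assumes a TabularLearner named learn has already been fit on dls (for instance via tabular_learner), which is not shown in this notebook.

# Hypothetical inference step: `learn` is assumed to be a TabularLearner
# already trained on `dls`; get_preds runs it over the test TabDataLoader.
preds, _ = learn.get_preds(dl=tdl)
print(preds.shape)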

Export -

from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted index_original.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.