表格学习器

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai when running on Colab (i.e. when /content exists)
from __future__ import annotations
from fastai.basics import *
from fastai.tabular.core import *
from fastai.tabular.model import *
from fastai.tabular.data import *
from nbdev.showdoc import *

用于快速获取一个可直接在表格数据上训练的 `Learner` 的函数

本模块中您最可能用到的主要函数是 `tabular_learner`。它会根据您的数据自动创建合适的 `TabularModel`,并推断出正确的损失函数。实际使用示例请参见表格教程。

主要功能

class TabularLearner(Learner):
    "`Learner` for tabular data"
    def predict(self,
        row:pd.Series, # Features of the single sample to predict on
    ):
        "Predict on a single sample"
        # Turn the Series into a one-row DataFrame so it can go through `test_dl`
        one_row_df = row.to_frame().T
        test_loader = self.dls.test_dl(one_row_df)
        # The transposed frame can carry non-float dtypes; force the continuous columns to float32
        test_loader.dataset.conts = test_loader.dataset.conts.astype(np.float32)
        inputs, raw_preds, _, decoded = self.get_preds(dl=test_loader, with_input=True, with_decoded=True)
        # Decode inputs and decoded predictions together to rebuild a human-readable row
        decoded_row = self.dls.decode((*tuplify(inputs), *tuplify(decoded)))
        # Return: decoded full row, decoded prediction and raw prediction for the only sample
        return decoded_row, decoded[0], raw_preds[0]
# Render the API documentation for `TabularLearner` with nbdev's `show_doc`
show_doc(TabularLearner, title_level=3)

class TabularLearner[source]

TabularLearner(dls, model, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner

Learner for tabular data

它的工作方式与普通的 `Learner` 完全相同,唯一的区别是它实现了一个专门用于处理单行数据的 `predict` 方法。

@delegates(Learner.__init__)
def tabular_learner(
        dls:TabularDataLoaders,
        layers:list=None, # Size of the layers generated by `LinBnDrop`
        emb_szs:list=None, # Tuples of `n_unique, embedding_size` for all categorical features
        config:dict=None, # Config params for TabularModel from `tabular_config`
        n_out:int=None, # Final output size of the model
        y_range:Tuple=None, # Low and high for the final sigmoid function
        **kwargs
):
    "Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params."
    # Work on a private copy so the caller's config is never mutated
    # (the same config may be reused to build several learners).
    config = dict(tabular_config() if config is None else config)
    if layers is None: layers = [200,100]
    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # Always strip `y_range` from the config so it is not passed twice to `TabularModel`
    # (which would raise a TypeError); an explicitly passed `y_range` takes precedence.
    config_y_range = config.pop('y_range', None)
    if y_range is None: y_range = config_y_range
    model = TabularModel(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **config)
    return TabularLearner(dls, model, **kwargs)

如果您的数据是使用 fastai 构建的,通常无需向 `emb_szs` 传递任何内容,除非您想更改库的默认值(由 `get_emb_sz` 生成)。`n_out` 也是如此,它应该会被自动推断。`layers` 的默认值为 `[200,100]`,它会与 `config` 一起传递给 `TabularModel`。

使用 `tabular_config` 创建一个 `config`,以定制所用的模型。由于 `y_range` 很常用,也可以直接通过 `y_range` 参数方便地设置它。

其余所有参数都会传递给 `Learner`。

# Fetch the ADULT_SAMPLE dataset with `untar_data` and load its CSV into a DataFrame
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# Categorical and continuous input columns; the target column is "salary" (see `y_names` below)
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
# Preprocessing steps applied to the table
procs = [Categorify, FillMissing, Normalize]
# Rows 800-999 are passed as `valid_idx` (the validation rows); batch size 64
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, 
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
# Build a learner with all defaults: layers [200,100], embedding sizes and n_out inferred from `dls`
learn = tabular_learner(dls)
show_doc(TabularLearner.predict)

TabularLearner.predict[source]

TabularLearner.predict(row)

Predict on a single sample

我们可以将数据中的单独一行传入 `TabularLearner` 的 `predict` 方法。它的输出与其他 `predict` 方法略有不同,因为该方法总是同时返回输入值:

# Predict on the first row: returns the decoded row, the decoded prediction and the raw prediction
row, clas, probs = learn.predict(df.iloc[0])
row.show()
workclass education marital-status occupation relationship race education-num_na age fnlwgt education-num salary
0 Private Assoc-acdm Married-civ-spouse #na# Wife White False 49.0 101320.001685 12.0 <50k
# Display the decoded prediction and the raw prediction tensor for that row
clas, probs
(tensor(0), tensor([0.5264, 0.4736]))
# Test that `y_range` reaches the final `SigmoidRange` layer, whether passed directly or via `tabular_config`
learn = tabular_learner(dls, y_range=(0,32))
assert isinstance(learn.model.layers[-1], SigmoidRange)
test_eq(learn.model.layers[-1].low, 0)
test_eq(learn.model.layers[-1].high, 32)

learn = tabular_learner(dls, config = tabular_config(y_range=(0,32)))
assert isinstance(learn.model.layers[-1], SigmoidRange)
test_eq(learn.model.layers[-1].low, 0)
test_eq(learn.model.layers[-1].high, 32)
@typedispatch
def show_results(x:Tabular, y:Tabular, samples, outs, ctxs=None, max_n=10, **kwargs):
    "Display up to `max_n` rows of `x` with a `<target>_pred` column added for each name in `x.y_names`."
    # `x.all_cols[:max_n]` is a row slice; take an explicit copy so adding the prediction
    # columns below does not trigger pandas' SettingWithCopyWarning.
    df = x.all_cols[:max_n].copy()
    for n in x.y_names: df[n+'_pred'] = y[n][:max_n].values
    display_df(df)

导出 -

# Run nbdev's notebook export (it prints one "Converted ..." line per notebook)
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 36_text.models.qrnn.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.cutmix.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted index.ipynb.
Converted tutorial.ipynb.