! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastai
表格学习器
from __future__ import annotations
from fastai.basics import *
from fastai.tabular.core import *
from fastai.tabular.model import *
from fastai.tabular.data import *
from nbdev.showdoc import *
用于立即获取一个可以训练表格数据的 `Learner` 的函数
您可能想要在此模块中使用的主要功能是 tabular_learner
。它将自动为您的数据创建一个适合的 TabularModel
并推断出正确的损失函数。有关上下文中使用示例,请参见 表格教程。
主要功能
class TabularLearner(Learner):
    "`Learner` for tabular data"
    def predict(self,
        row:pd.Series, # Features to be predicted
    ):
        "Predict on a single sample"
        # Wrap the single row in a one-row DataFrame and build a test DataLoader
        # that applies the same preprocessing as training.
        dl = self.dls.test_dl(row.to_frame().T)
        # Continuous columns must be float32 for the model's linear layers.
        dl.dataset.conts = dl.dataset.conts.astype(np.float32)
        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
        # Combine the (categorical, continuous) inputs with the decoded
        # predictions so the whole row can be decoded back to readable values.
        b = (*tuplify(inp),*tuplify(dec_preds))
        full_dec = self.dls.decode(b)
        # Returns: decoded full row, decoded prediction, raw probabilities.
        return full_dec,dec_preds[0],preds[0]
=3) show_doc(TabularLearner, title_level
class
TabularLearner
[source]
TabularLearner
(dls
,model
,loss_func
=None
,opt_func
=Adam
,lr
=0.001
,splitter
=trainable_params
,cbs
=None
,metrics
=None
,path
=None
,model_dir
='models'
,wd
=None
,wd_bn_bias
=False
,train_bn
=True
,moms
=(0.95, 0.85, 0.95)
) ::Learner
Learner
for tabular data
它的工作方式与普通的 Learner
完全相同,唯一的区别是它实现了一个特定于处理数据行的 predict
方法。
@delegates(Learner.__init__)
def tabular_learner(
    dls:TabularDataLoaders,
    layers:list=None, # Size of the layers generated by `LinBnDrop`
    emb_szs:list=None, # Tuples of `n_unique, embedding_size` for all categorical features
    config:dict=None, # Config params for TabularModel from `tabular_config`
    n_out:int=None, # Final output size of the model
    y_range:Tuple=None, # Low and high for the final sigmoid function
    **kwargs
):
    "Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params."
    if config is None: config = tabular_config()
    if layers is None: layers = [200,100]
    # Infer embedding sizes from the training set, letting user overrides win.
    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)
    # Infer the output size from the data when not given explicitly.
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    # `y_range` may also arrive via `config`; the explicit argument takes priority.
    if y_range is None and 'y_range' in config: y_range = config.pop('y_range')
    model = TabularModel(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **config)
    return TabularLearner(dls, model, **kwargs)
如果您的数据是使用fastai构建的,您可能不需要向emb_szs
传递任何内容,除非您想更改库的默认值(由get_emb_sz
生成),n_out
也是如此,它应该会被自动推断。layers
的默认值将为[200,100]
,并与config
一起传递给TabularModel
。
使用tabular_config
创建一个config
并自定义所使用的模型。由于此参数通常被使用,因此可以方便地访问y_range
。
所有其他参数都传递给Learner
。
# Build example DataLoaders on the ADULT_SAMPLE dataset and a default learner.
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
learn = tabular_learner(dls)
show_doc(TabularLearner.predict)
我们可以将数据的单独一行传入我们的 TabularLearner
的 predict
方法。它的输出与其他 predict
方法略有不同,因为这个方法将始终返回输入值:
# `predict` also returns the decoded input row, unlike other predict methods.
row, clas, probs = learn.predict(df.iloc[0])
row.show()
workclass | education | marital-status | occupation | relationship | race | education-num_na | age | fnlwgt | education-num | salary | |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | Private | Assoc-acdm | Married-civ-spouse | #na# | Wife | White | False | 49.0 | 101320.001685 | 12.0 | <50k |
clas, probs
(tensor(0), tensor([0.5264, 0.4736]))
# Test that y_range is passed through to the model
learn = tabular_learner(dls, y_range=(0,32))
assert isinstance(learn.model.layers[-1], SigmoidRange)
test_eq(learn.model.layers[-1].low, 0)
test_eq(learn.model.layers[-1].high, 32)

# y_range supplied inside `config` must work the same way
learn = tabular_learner(dls, config = tabular_config(y_range=(0,32)))
assert isinstance(learn.model.layers[-1], SigmoidRange)
test_eq(learn.model.layers[-1].low, 0)
test_eq(learn.model.layers[-1].high, 32)
@typedispatch
def show_results(x:Tabular, y:Tabular, samples, outs, ctxs=None, max_n=10, **kwargs):
    "Display up to `max_n` rows with one extra `<target>_pred` column per target"
    df = x.all_cols[:max_n]
    for n in x.y_names: df[n+'_pred'] = y[n][:max_n].values
    display_df(df)
导出 -
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 36_text.models.qrnn.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.cutmix.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted index.ipynb.
Converted tutorial.ipynb.