
Searching for the Best Architecture and Hyperparameters

Sometimes (or often) we don't know exactly which architecture best fits our data. In AI, an architecture that performs best on one dataset may perform poorly on another. To help find the best solution, this notebook uses two main features of PyTorch Tabular. The first is the Sweep, which runs every architecture available in PyTorch Tabular with its default hyperparameters, to discover which one is likely the best fit for our data. After that, we use the Tuner to search for the best hyperparameters of the best architecture found by the Sweep.

import warnings
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split

from pytorch_tabular.utils import make_mixed_dataset
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig

Data

First, we create a synthetic dataset that mixes numerical and categorical features, with a single classification target column.

data, cat_col_names, num_col_names = make_mixed_dataset(
    task="classification", n_samples=3000, n_features=7, n_categories=4
)

train, test = train_test_split(data, random_state=42)
train, valid = train_test_split(train, random_state=42)
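Since no `test_size` is passed, both calls above use sklearn's default 75/25 split. A stdlib-only sketch of the resulting row counts, assuming sklearn's behavior of rounding the test share up:

```python
from math import ceil

# Reproduce the split sizes train_test_split produces for our 3000-row dataset,
# assuming sklearn's default test_size of 0.25 (the test share is rounded up).
n = 3000
n_test = ceil(n * 0.25)              # rows held out for test
n_train_full = n - n_test            # rows left after the first split
n_valid = ceil(n_train_full * 0.25)  # rows held out for validation
n_train = n_train_full - n_valid     # rows used for training
print(n_train, n_valid, n_test)  # 1687 563 750
```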

Common Configurations

data_config = DataConfig(
    target=[
        "target"
    ],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    batch_size=32,
    max_epochs=50,
    early_stopping="valid_accuracy",
    early_stopping_mode="max",
    early_stopping_patience=3,
    checkpoints="valid_accuracy",
    load_best=True,
    progress_bar="none"
)
optimizer_config = OptimizerConfig()

Model Sweep

https://pytorch-tabular.readthedocs.io/en/latest/apidocs_coreclasses/#pytorch_tabular.model_sweep

Let's train all the available models (the "high_memory" list). If some of them return "OOM", it means your current batch_size runs out of memory. You can either skip that model or reduce the batch_size in TrainerConfig.

from pytorch_tabular import model_sweep
sweep_df, best_model = model_sweep(
                            task="classification",
                            train=train,
                            test=valid,
                            data_config=data_config,
                            optimizer_config=optimizer_config,
                            trainer_config=trainer_config,
                            model_list="high_memory",
                            verbose=False # set to True if you want to log metrics and params in every trial.
                        )
2024-07-20 12:47:01,862 - {pytorch_tabular.models.node.node_model:73} - INFO - Data Aware Initialization of NODE using a forward pass with 2000 batch size....


best_model.evaluate(test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8053333163261414     │
│         test_loss             0.44678735733032227    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.44678735733032227, 'test_accuracy': 0.8053333163261414}]

In the table below we can see the best models (with default hyperparameters) for our dataset. But we're not satisfied with that, so in this case we'll take the top two models and search for better hyperparameters with the Tuner, trying to get a better result.

Note: results may vary each time you run the notebook, so you may see different top models than the ones we use in the next section.
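"Take the top two models" boils down to sorting the sweep results by test accuracy and keeping the two best names. A stdlib-only sketch, with hypothetical `(model, test_accuracy)` pairs standing in for the `sweep_df` rows:

```python
# Hypothetical (model, test_accuracy) pairs standing in for sweep_df rows.
sweep_rows = [
    ("CategoryEmbeddingModel", 0.7975),
    ("FTTransformerModel", 0.7709),
    ("GANDALFModel", 0.7052),
    ("TabTransformerModel", 0.6963),
]
# Sort by accuracy, descending, and keep the two best model names.
top2 = [name for name, _ in sorted(sweep_rows, key=lambda r: r[1], reverse=True)[:2]]
print(top2)  # ['CategoryEmbeddingModel', 'FTTransformerModel']
```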

sweep_df.drop(columns=["params", "time_taken", "epochs"]).sort_values("test_accuracy", ascending=False).style.background_gradient(
    subset=["test_accuracy"], cmap="RdYlGn"
).background_gradient(subset=["time_taken_per_epoch", "test_loss"], cmap="RdYlGn_r")
  model # Params test_loss test_accuracy time_taken_per_epoch
1 CategoryEmbeddingModel 12 T 0.458506 0.797513 0.190966
3 FTTransformerModel 272 T 0.486184 0.770870 0.529126
4 GANDALFModel 8 T 0.562945 0.705151 0.341467
8 TabTransformerModel 272 T 0.547346 0.696270 0.470920
0 AutoIntModel 14 T 0.580009 0.689165 0.360073
5 GatedAdditiveTreeEnsembleModel 79 T 0.673274 0.660746 3.624957
2 DANetModel 431 T 0.692986 0.644760 2.104359
6 NODEModel 864 T 0.676671 0.626998 1.497243
7 TabNetModel 6 T 0.708919 0.538188 0.484836

Model Tuner

https://pytorch-tabular.readthedocs.io/en/latest/apidocs_coreclasses/#pytorch_tabular.TabularModelTuner

Awesome!! Now that we know the best models, let's take the top two and tune their hyperparameters, hoping to find a better result.

from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    FTTransformerConfig
)   

We can use two main strategies:
- grid_search: searches every combination of the defined hyperparameters. Keep in mind that each new field you add multiplies the total training time significantly: with 4 optimizers, 4 layer configurations, 2 activation functions, and 2 dropout rates, that means 64 (4 * 4 * 2 * 2) trainings.
- random_search: randomly samples `n_trials` hyperparameter settings for each defined model. This is useful for faster runs, but remember it will not test every combination.
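The grid size is just the Cartesian product of the option lists, which `itertools.product` makes explicit. A sketch using a hypothetical 4 x 4 x 2 x 2 search space mirroring the example above (the optimizer list is padded to four entries for the count):

```python
from itertools import product

# Hypothetical search space matching the 4 x 4 x 2 x 2 example above.
search_space = {
    "optimizer_config__optimizer": ["Adam", "AdamW", "SGD", "RMSprop"],
    "model_config__layers": ["128-64-32", "1024-512-256", "32-64-128", "256-512-1024"],
    "model_config__activation": ["ReLU", "LeakyReLU"],
    "model_config__embedding_dropout": [0.0, 0.2],
}

# grid_search trains one model per element of the Cartesian product.
n_trials = len(list(product(*search_space.values())))
print(n_trials)  # 64
```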

For more information on all the hyperparameter options: https://pytorch-tabular.readthedocs.io/en/latest/apidocs_model/

For more on how the hyperparameter space works: https://pytorch-tabular.readthedocs.io/en/latest/tutorials/10-Hyperparameter%20Tuning/#define-the-hyperparameter-space

Let's define some hyperparameters.

PS: this notebook is meant to demonstrate these functions; these are not necessarily the best hyperparameters to try.

search_space_category_embedding = {
    "optimizer_config__optimizer": ["Adam", "SGD"],
    "model_config__layers": ["128-64-32", "1024-512-256", "32-64-128", "256-512-1024"],
    "model_config__activation": ["ReLU", "LeakyReLU"],
    "model_config__embedding_dropout": [0.0, 0.2],
}
model_config_category_embedding = CategoryEmbeddingModelConfig(task="classification")
search_space_ft_transformer = {
    "optimizer_config__optimizer": ["Adam", "SGD"],
    "model_config__input_embed_dim": [32, 64],
    "model_config__num_attn_blocks": [3, 6, 8],
    "model_config__ff_hidden_multiplier": [4, 8],
    "model_config__transformer_activation": ["GEGLU", "LeakyReLU"],
    "model_config__embedding_dropout": [0.0, 0.2],
}
model_config_ft_transformer = FTTransformerConfig(task="classification")

Let's add all the search spaces and model configs to lists.

IMPORTANT: they must be in the same order and have the same length.

search_spaces = [search_space_category_embedding, search_space_ft_transformer]
model_configs = [model_config_category_embedding, model_config_ft_transformer]
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_configs,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config
)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    tuner_df = tuner.tune(
        train=train,
        validation=valid,
        search_space=search_spaces,
        strategy="grid_search",  # or "random_search"
        # n_trials=5,
        metric="accuracy",
        mode="max",
        progress_bar=True,
        verbose=False # set to True if you want to log metrics and params for every trial.
    )


Great!!! We now know the best architecture for our dataset and some promising hyperparameters. Maybe the results still aren't good enough, but at least we have narrowed down the options. With these results we have a clearer picture of which hyperparameters are worth exploring further and which are not worth keeping.

It's also a good idea to read the architectures' papers; when possible, they can further guide you toward the best hyperparameters.

tuner_df.trials_df.sort_values("accuracy", ascending=False).style.background_gradient(
    subset=["accuracy"], cmap="RdYlGn"
).background_gradient(subset=["loss"], cmap="RdYlGn_r")
  trial_id model model_config__activation model_config__embedding_dropout model_config__layers optimizer_config__optimizer loss accuracy model_config__ff_hidden_multiplier model_config__input_embed_dim model_config__num_attn_blocks model_config__transformer_activation
22 22 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 256-512-1024 Adam 0.339012 0.857904 nan nan nan nan
26 26 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 1024-512-256 Adam 0.375515 0.817052 nan nan nan nan
20 20 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 32-64-128 Adam 0.368664 0.815275 nan nan nan nan
2 2 0-CategoryEmbeddingModelConfig ReLU 0.000000 1024-512-256 Adam 0.407023 0.813499 nan nan nan nan
6 6 0-CategoryEmbeddingModelConfig ReLU 0.000000 256-512-1024 Adam 0.445294 0.811723 nan nan nan nan
10 10 0-CategoryEmbeddingModelConfig ReLU 0.200000 1024-512-256 Adam 0.446737 0.811723 nan nan nan nan
18 18 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 1024-512-256 Adam 0.444420 0.808170 nan nan nan nan
30 30 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 256-512-1024 Adam 0.398530 0.797513 nan nan nan nan
14 14 0-CategoryEmbeddingModelConfig ReLU 0.200000 256-512-1024 Adam 0.455243 0.781528 nan nan nan nan
72 40 1-FTTransformerConfig nan 0.000000 nan Adam 0.445089 0.779751 8.000000 64.000000 6.000000 GEGLU
8 8 0-CategoryEmbeddingModelConfig ReLU 0.200000 128-64-32 Adam 0.486341 0.776199 nan nan nan nan
16 16 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 128-64-32 Adam 0.458817 0.776199 nan nan nan nan
116 84 1-FTTransformerConfig nan 0.200000 nan Adam 0.471312 0.776199 8.000000 64.000000 3.000000 GEGLU
62 30 1-FTTransformerConfig nan 0.000000 nan Adam 0.475959 0.774423 8.000000 32.000000 6.000000 LeakyReLU
36 4 1-FTTransformerConfig nan 0.000000 nan Adam 0.506062 0.772647 4.000000 32.000000 6.000000 GEGLU
28 28 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 32-64-128 Adam 0.503373 0.769094 nan nan nan nan
0 0 0-CategoryEmbeddingModelConfig ReLU 0.000000 128-64-32 Adam 0.482425 0.769094 nan nan nan nan
60 28 1-FTTransformerConfig nan 0.000000 nan Adam 0.495479 0.767318 8.000000 32.000000 6.000000 GEGLU
56 24 1-FTTransformerConfig nan 0.000000 nan Adam 0.519672 0.767318 8.000000 32.000000 3.000000 GEGLU
80 48 1-FTTransformerConfig nan 0.200000 nan Adam 0.518865 0.765542 4.000000 32.000000 3.000000 GEGLU
74 42 1-FTTransformerConfig nan 0.000000 nan Adam 0.483879 0.763766 8.000000 64.000000 6.000000 LeakyReLU
64 32 1-FTTransformerConfig nan 0.000000 nan Adam 0.575869 0.763766 8.000000 32.000000 8.000000 GEGLU
94 62 1-FTTransformerConfig nan 0.200000 nan Adam 0.484891 0.761989 4.000000 64.000000 3.000000 LeakyReLU
66 34 1-FTTransformerConfig nan 0.000000 nan Adam 0.506116 0.761989 8.000000 32.000000 8.000000 LeakyReLU
52 20 1-FTTransformerConfig nan 0.000000 nan Adam 0.511868 0.761989 4.000000 64.000000 8.000000 GEGLU
96 64 1-FTTransformerConfig nan 0.200000 nan Adam 0.482814 0.760213 4.000000 64.000000 6.000000 GEGLU
110 78 1-FTTransformerConfig nan 0.200000 nan Adam 0.479574 0.758437 8.000000 32.000000 6.000000 LeakyReLU
19 19 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 1024-512-256 SGD 0.532006 0.756661 nan nan nan nan
124 92 1-FTTransformerConfig nan 0.200000 nan Adam 0.532167 0.756661 8.000000 64.000000 8.000000 GEGLU
86 54 1-FTTransformerConfig nan 0.200000 nan Adam 0.462083 0.754885 4.000000 32.000000 6.000000 LeakyReLU
50 18 1-FTTransformerConfig nan 0.000000 nan Adam 0.503736 0.753108 4.000000 64.000000 6.000000 LeakyReLU
42 10 1-FTTransformerConfig nan 0.000000 nan Adam 0.470982 0.753108 4.000000 32.000000 8.000000 LeakyReLU
34 2 1-FTTransformerConfig nan 0.000000 nan Adam 0.503541 0.751332 4.000000 32.000000 3.000000 LeakyReLU
106 74 1-FTTransformerConfig nan 0.200000 nan Adam 0.504346 0.747780 8.000000 32.000000 3.000000 LeakyReLU
46 14 1-FTTransformerConfig nan 0.000000 nan Adam 0.488356 0.747780 4.000000 64.000000 3.000000 LeakyReLU
54 22 1-FTTransformerConfig nan 0.000000 nan Adam 0.561371 0.740675 4.000000 64.000000 8.000000 LeakyReLU
58 26 1-FTTransformerConfig nan 0.000000 nan Adam 0.494664 0.740675 8.000000 32.000000 3.000000 LeakyReLU
88 56 1-FTTransformerConfig nan 0.200000 nan Adam 0.527474 0.738899 4.000000 32.000000 8.000000 GEGLU
84 52 1-FTTransformerConfig nan 0.200000 nan Adam 0.508179 0.731794 4.000000 32.000000 6.000000 GEGLU
118 86 1-FTTransformerConfig nan 0.200000 nan Adam 0.511033 0.731794 8.000000 64.000000 3.000000 LeakyReLU
120 88 1-FTTransformerConfig nan 0.200000 nan Adam 0.473721 0.731794 8.000000 64.000000 6.000000 GEGLU
98 66 1-FTTransformerConfig nan 0.200000 nan Adam 0.518997 0.731794 4.000000 64.000000 6.000000 LeakyReLU
31 31 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 256-512-1024 SGD 0.538754 0.731794 nan nan nan nan
40 8 1-FTTransformerConfig nan 0.000000 nan Adam 0.546107 0.731794 4.000000 32.000000 8.000000 GEGLU
4 4 0-CategoryEmbeddingModelConfig ReLU 0.000000 32-64-128 Adam 0.533960 0.728242 nan nan nan nan
70 38 1-FTTransformerConfig nan 0.000000 nan Adam 0.579302 0.726465 8.000000 64.000000 3.000000 LeakyReLU
12 12 0-CategoryEmbeddingModelConfig ReLU 0.200000 32-64-128 Adam 0.508314 0.724689 nan nan nan nan
38 6 1-FTTransformerConfig nan 0.000000 nan Adam 0.538916 0.721137 4.000000 32.000000 6.000000 LeakyReLU
82 50 1-FTTransformerConfig nan 0.200000 nan Adam 0.537538 0.721137 4.000000 32.000000 3.000000 LeakyReLU
122 90 1-FTTransformerConfig nan 0.200000 nan Adam 0.522755 0.719361 8.000000 64.000000 6.000000 LeakyReLU
48 16 1-FTTransformerConfig nan 0.000000 nan Adam 0.471181 0.715808 4.000000 64.000000 6.000000 GEGLU
32 0 1-FTTransformerConfig nan 0.000000 nan Adam 0.550226 0.714032 4.000000 32.000000 3.000000 GEGLU
108 76 1-FTTransformerConfig nan 0.200000 nan Adam 0.523274 0.714032 8.000000 32.000000 6.000000 GEGLU
63 31 1-FTTransformerConfig nan 0.000000 nan SGD 0.591639 0.712256 8.000000 32.000000 6.000000 LeakyReLU
104 72 1-FTTransformerConfig nan 0.200000 nan Adam 0.508801 0.710480 8.000000 32.000000 3.000000 GEGLU
24 24 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 128-64-32 Adam 0.519161 0.710480 nan nan nan nan
68 36 1-FTTransformerConfig nan 0.000000 nan Adam 0.572089 0.706927 8.000000 64.000000 3.000000 GEGLU
92 60 1-FTTransformerConfig nan 0.200000 nan Adam 0.575852 0.706927 4.000000 64.000000 3.000000 GEGLU
126 94 1-FTTransformerConfig nan 0.200000 nan Adam 0.570989 0.706927 8.000000 64.000000 8.000000 LeakyReLU
44 12 1-FTTransformerConfig nan 0.000000 nan Adam 0.577062 0.705151 4.000000 64.000000 3.000000 GEGLU
79 47 1-FTTransformerConfig nan 0.000000 nan SGD 0.557485 0.703375 8.000000 64.000000 8.000000 LeakyReLU
51 19 1-FTTransformerConfig nan 0.000000 nan SGD 0.550771 0.703375 4.000000 64.000000 6.000000 LeakyReLU
11 11 0-CategoryEmbeddingModelConfig ReLU 0.200000 1024-512-256 SGD 0.555238 0.703375 nan nan nan nan
114 82 1-FTTransformerConfig nan 0.200000 nan Adam 0.487832 0.701599 8.000000 32.000000 8.000000 LeakyReLU
90 58 1-FTTransformerConfig nan 0.200000 nan Adam 0.579668 0.699822 4.000000 32.000000 8.000000 LeakyReLU
3 3 0-CategoryEmbeddingModelConfig ReLU 0.000000 1024-512-256 SGD 0.572410 0.696270 nan nan nan nan
112 80 1-FTTransformerConfig nan 0.200000 nan Adam 0.553881 0.692718 8.000000 32.000000 8.000000 GEGLU
15 15 0-CategoryEmbeddingModelConfig ReLU 0.200000 256-512-1024 SGD 0.562511 0.685613 nan nan nan nan
35 3 1-FTTransformerConfig nan 0.000000 nan SGD 0.581403 0.685613 4.000000 32.000000 3.000000 LeakyReLU
45 13 1-FTTransformerConfig nan 0.000000 nan SGD 0.597738 0.685613 4.000000 64.000000 3.000000 GEGLU
100 68 1-FTTransformerConfig nan 0.200000 nan Adam 0.584579 0.683837 4.000000 64.000000 8.000000 GEGLU
83 51 1-FTTransformerConfig nan 0.200000 nan SGD 0.662541 0.680284 4.000000 32.000000 3.000000 LeakyReLU
127 95 1-FTTransformerConfig nan 0.200000 nan SGD 0.614641 0.676732 8.000000 64.000000 8.000000 LeakyReLU
101 69 1-FTTransformerConfig nan 0.200000 nan SGD 0.579955 0.674956 4.000000 64.000000 8.000000 GEGLU
102 70 1-FTTransformerConfig nan 0.200000 nan Adam 0.585392 0.671403 4.000000 64.000000 8.000000 LeakyReLU
27 27 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 1024-512-256 SGD 0.594700 0.667851 nan nan nan nan
7 7 0-CategoryEmbeddingModelConfig ReLU 0.000000 256-512-1024 SGD 0.598617 0.666075 nan nan nan nan
121 89 1-FTTransformerConfig nan 0.200000 nan SGD 0.632152 0.666075 8.000000 64.000000 6.000000 GEGLU
76 44 1-FTTransformerConfig nan 0.000000 nan Adam 0.641684 0.666075 8.000000 64.000000 8.000000 GEGLU
103 71 1-FTTransformerConfig nan 0.200000 nan SGD 0.616750 0.666075 4.000000 64.000000 8.000000 LeakyReLU
91 59 1-FTTransformerConfig nan 0.200000 nan SGD 0.634522 0.664298 4.000000 32.000000 8.000000 LeakyReLU
59 27 1-FTTransformerConfig nan 0.000000 nan SGD 0.624750 0.664298 8.000000 32.000000 3.000000 LeakyReLU
107 75 1-FTTransformerConfig nan 0.200000 nan SGD 0.637458 0.657194 8.000000 32.000000 3.000000 LeakyReLU
69 37 1-FTTransformerConfig nan 0.000000 nan SGD 0.636728 0.657194 8.000000 64.000000 3.000000 GEGLU
55 23 1-FTTransformerConfig nan 0.000000 nan SGD 0.613378 0.657194 4.000000 64.000000 8.000000 LeakyReLU
5 5 0-CategoryEmbeddingModelConfig ReLU 0.000000 32-64-128 SGD 0.670955 0.655417 nan nan nan nan
117 85 1-FTTransformerConfig nan 0.200000 nan SGD 0.629454 0.655417 8.000000 64.000000 3.000000 GEGLU
97 65 1-FTTransformerConfig nan 0.200000 nan SGD 0.645757 0.655417 4.000000 64.000000 6.000000 GEGLU
87 55 1-FTTransformerConfig nan 0.200000 nan SGD 0.646177 0.651865 4.000000 32.000000 6.000000 LeakyReLU
23 23 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 256-512-1024 SGD 0.639443 0.650089 nan nan nan nan
39 7 1-FTTransformerConfig nan 0.000000 nan SGD 0.651099 0.646536 4.000000 32.000000 6.000000 LeakyReLU
125 93 1-FTTransformerConfig nan 0.200000 nan SGD 0.624359 0.646536 8.000000 64.000000 8.000000 GEGLU
65 33 1-FTTransformerConfig nan 0.000000 nan SGD 0.597288 0.644760 8.000000 32.000000 8.000000 GEGLU
47 15 1-FTTransformerConfig nan 0.000000 nan SGD 0.666151 0.644760 4.000000 64.000000 3.000000 LeakyReLU
49 17 1-FTTransformerConfig nan 0.000000 nan SGD 0.639839 0.642984 4.000000 64.000000 6.000000 GEGLU
123 91 1-FTTransformerConfig nan 0.200000 nan SGD 0.628552 0.641208 8.000000 64.000000 6.000000 LeakyReLU
75 43 1-FTTransformerConfig nan 0.000000 nan SGD 0.619922 0.641208 8.000000 64.000000 6.000000 LeakyReLU
85 53 1-FTTransformerConfig nan 0.200000 nan SGD 0.655388 0.641208 4.000000 32.000000 6.000000 GEGLU
89 57 1-FTTransformerConfig nan 0.200000 nan SGD 0.635567 0.637655 4.000000 32.000000 8.000000 GEGLU
93 61 1-FTTransformerConfig nan 0.200000 nan SGD 0.658716 0.635879 4.000000 64.000000 3.000000 GEGLU
71 39 1-FTTransformerConfig nan 0.000000 nan SGD 0.646253 0.634103 8.000000 64.000000 3.000000 LeakyReLU
67 35 1-FTTransformerConfig nan 0.000000 nan SGD 0.667418 0.632327 8.000000 32.000000 8.000000 LeakyReLU
81 49 1-FTTransformerConfig nan 0.200000 nan SGD 0.664191 0.628774 4.000000 32.000000 3.000000 GEGLU
119 87 1-FTTransformerConfig nan 0.200000 nan SGD 0.665687 0.628774 8.000000 64.000000 3.000000 LeakyReLU
113 81 1-FTTransformerConfig nan 0.200000 nan SGD 0.651145 0.628774 8.000000 32.000000 8.000000 GEGLU
53 21 1-FTTransformerConfig nan 0.000000 nan SGD 0.672665 0.628774 4.000000 64.000000 8.000000 GEGLU
25 25 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 128-64-32 SGD 0.662720 0.625222 nan nan nan nan
111 79 1-FTTransformerConfig nan 0.200000 nan SGD 0.633410 0.623446 8.000000 32.000000 6.000000 LeakyReLU
43 11 1-FTTransformerConfig nan 0.000000 nan SGD 0.635329 0.621670 4.000000 32.000000 8.000000 LeakyReLU
99 67 1-FTTransformerConfig nan 0.200000 nan SGD 0.628636 0.616341 4.000000 64.000000 6.000000 LeakyReLU
41 9 1-FTTransformerConfig nan 0.000000 nan SGD 0.639925 0.616341 4.000000 32.000000 8.000000 GEGLU
109 77 1-FTTransformerConfig nan 0.200000 nan SGD 0.651506 0.614565 8.000000 32.000000 6.000000 GEGLU
1 1 0-CategoryEmbeddingModelConfig ReLU 0.000000 128-64-32 SGD 0.665929 0.612789 nan nan nan nan
105 73 1-FTTransformerConfig nan 0.200000 nan SGD 0.658312 0.605684 8.000000 32.000000 3.000000 GEGLU
37 5 1-FTTransformerConfig nan 0.000000 nan SGD 0.652759 0.605684 4.000000 32.000000 6.000000 GEGLU
73 41 1-FTTransformerConfig nan 0.000000 nan SGD 0.659291 0.598579 8.000000 64.000000 6.000000 GEGLU
33 1 1-FTTransformerConfig nan 0.000000 nan SGD 0.648887 0.596803 4.000000 32.000000 3.000000 GEGLU
17 17 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 128-64-32 SGD 0.710187 0.589698 nan nan nan nan
77 45 1-FTTransformerConfig nan 0.000000 nan SGD 0.648749 0.589698 8.000000 64.000000 8.000000 GEGLU
29 29 0-CategoryEmbeddingModelConfig LeakyReLU 0.200000 32-64-128 SGD 0.719664 0.582593 nan nan nan nan
13 13 0-CategoryEmbeddingModelConfig ReLU 0.200000 32-64-128 SGD 0.778426 0.555950 nan nan nan nan
61 29 1-FTTransformerConfig nan 0.000000 nan SGD 0.689890 0.552398 8.000000 32.000000 6.000000 GEGLU
57 25 1-FTTransformerConfig nan 0.000000 nan SGD 0.690501 0.539964 8.000000 32.000000 3.000000 GEGLU
21 21 0-CategoryEmbeddingModelConfig LeakyReLU 0.000000 32-64-128 SGD 0.726256 0.539964 nan nan nan nan
95 63 1-FTTransformerConfig nan 0.200000 nan SGD 0.701837 0.502664 4.000000 64.000000 3.000000 LeakyReLU
115 83 1-FTTransformerConfig nan 0.200000 nan SGD 0.680208 0.502664 8.000000 32.000000 8.000000 LeakyReLU
78 46 1-FTTransformerConfig nan 0.000000 nan Adam 0.693390 0.493783 8.000000 64.000000 8.000000 LeakyReLU
9 9 0-CategoryEmbeddingModelConfig ReLU 0.200000 128-64-32 SGD 0.781076 0.433393 nan nan nan nan
tuner_df.best_model.evaluate(test)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.8173333406448364     │
│         test_loss             0.38250666856765747    │
└───────────────────────────┴───────────────────────────┘
[{'test_loss': 0.38250666856765747, 'test_accuracy': 0.8173333406448364}]

After training, the best model is stored in the output variable as "best_model". So if you're happy with the results and want to use this model in the future, you can call "save_model" to save it.

tuner_df.best_model.save_model("best_model", inference_only=True)
2024-07-20 12:58:01,015 - {pytorch_tabular.tabular_model:1572} - WARNING - Directory is not empty. Overwriting the contents.
# Load the saved model
# from pytorch_tabular import TabularModel
# loaded_model = TabularModel.load_model("best_model")
# loaded_model.evaluate(test)