导入库¶
测试时增强¶
测试时增强(TTA)是计算机视觉中一种流行的技术。TTA的目的是通过在推理阶段使用数据增强来提高模型的准确性。TTA背后的理念很简单:对于每一张测试图像,我们创建多个与原始图像稍有不同的版本(例如,裁剪或翻转)。接下来,我们对测试图像及其创建的副本进行预测,并对每张图像的多个版本的模型预测结果进行平均。这通常有助于提高准确性,而不管基础模型如何。
有关更多细节,请参考此链接:表格数据的测试时增强
data_config = DataConfig(
target=[
target_col
], # 目标应始终为一个列表。仅在回归任务中支持多目标。多任务分类尚未实现。
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
batch_size=1024,
max_epochs=100,
early_stopping="valid_loss", # 监视有效的损失以进行提前停止
early_stopping_mode="min", # 将模式设置为min,因为在验证损失(val_loss)中,数值越低越好。
early_stopping_patience=5, # 在终止之前等待的退化训练的轮次数
checkpoints="valid_loss", # 保存最佳检查点监控验证损失
load_best=True, # 训练后,加载最佳检查点
# progress_bar="none", # Turning off Progress bar
# trainer_kwargs=dict(
# enable_model_summary=False 关闭模型摘要
# )
)
optimizer_config = OptimizerConfig()
head_config = LinearHeadConfig(
layers="", dropout=0.1, initialization="kaiming" # 头部没有额外的层,仅包含一个映射层,输出维度为output_dim。
).__dict__ # 转换为字典以传递给模型配置(OmegaConf不接受对象)
model1_config = CategoryEmbeddingModelConfig(
task="classification",
layers="1024-512-512", # 每层节点的数量
activation="LeakyReLU", # 各层之间的激活
learning_rate=1e-3,
head="LinearHead", # 线性磁头
head_config=head_config, # 线性磁头配置
)
model_config = CategoryEmbeddingModelConfig(
task="classification",
layers="1024-512-512", # 每层节点的数量
activation="LeakyReLU", # 各层之间的激活
learning_rate=1e-3,
head="LinearHead", # 线性磁头
head_config=head_config, # 线性磁头配置
)
tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
verbose=False
)