使用
PyTorch Tabular 自带智能默认设置,使得开始使用表格深度学习变得简单。然而,它也提供了灵活性,可以根据您的需求自定义模型和流水线。
以下是一个简单的示例,展示如何使用 PyTorch Tabular 来训练模型、在新数据上进行评估、生成预测,以及保存和加载模型。
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
DataConfig,
OptimizerConfig,
TrainerConfig,
)
data_config = DataConfig(
target=[
"target"
], # 目标应始终为列表。
continuous_cols=num_col_names,
categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
auto_lr_find=True, # 运行LRFinder以自动推导学习率
batch_size=1024,
max_epochs=100,
)
optimizer_config = OptimizerConfig()
model_config = CategoryEmbeddingModelConfig(
task="classification",
layers="1024-512-512", # 每层中的节点数
activation="LeakyReLU", # 每层之间的激活函数
learning_rate=1e-3,
)
tabular_model = TabularModel(
data_config=data_config,
model_config=model_config,
optimizer_config=optimizer_config,
trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
result = tabular_model.evaluate(test)
pred_df = tabular_model.predict(test)
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_model("examples/basic")
如需更详细的教程和操作指南,请参阅 教程 和 操作指南 部分。