使用

PyTorch Tabular 自带智能默认设置,使得开始使用表格深度学习变得简单。然而,它也提供了灵活性,可以根据您的需求自定义模型和流水线。

以下是一个简单的示例,展示如何使用 PyTorch Tabular 来训练模型、在新数据上进行评估、生成预测,以及保存和加载模型。

from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
)

data_config = DataConfig(
    target=[
        "target"
    ],  # 目标应始终为列表。
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
    auto_lr_find=True,  # 运行LRFinder以自动推导学习率
    batch_size=1024,
    max_epochs=100,
)
optimizer_config = OptimizerConfig()

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # 每层中的节点数
    activation="LeakyReLU",  # 每层之间的激活函数
    learning_rate=1e-3,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
tabular_model.fit(train=train, validation=val)
result = tabular_model.evaluate(test)
pred_df = tabular_model.predict(test)
tabular_model.save_model("examples/basic")
loaded_model = TabularModel.load_model("examples/basic")

如需更详细的教程和操作指南,请参阅 教程操作指南 部分。