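Experiment tracking is essential in machine learning because it lets data scientists and researchers manage and reproduce their experiments. Recording each aspect of an experiment, such as hyperparameters, model architecture, and training data, makes results easier to understand and interpret. It also improves collaboration and knowledge sharing within a team by providing a central repository of experiments and their metadata. Tracking helps with debugging and troubleshooting as well, because it lets you pinpoint the specific settings or conditions that led to a successful or failed result. Overall, experiment tracking plays a key role in keeping machine learning workflows transparent, reproducible, and continuously improving.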

Now let's see how to get all of these benefits for free using Weights & Biases and PyTorch Tabular.

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import random
from pytorch_tabular.utils import load_covertype_dataset, print_metrics
import pandas as pd
import wandb

# %load_ext autoreload
# %autoreload 2
wandb.login()
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: manujosephv. Use `wandb login --relogin` to force relogin

True
data, cat_col_names, num_col_names, target_col = load_covertype_dataset()
train, test = train_test_split(data, random_state=42)
train, val = train_test_split(train, random_state=42)
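
Because `train_test_split` holds out 25% of the rows by default when neither `test_size` nor `train_size` is given, the two calls above leave roughly 56% of the data for training, 19% for validation, and 25% for testing. The sketch below works through that arithmetic in plain Python. The row count of 581,012 is an assumption (the size of the standard Covertype dataset; check `len(data)` in your own session), and `split_sizes` is a hypothetical helper that rounds the held-out part up, as scikit-learn does for a fractional `test_size`.

```python
import math


def split_sizes(n_rows, test_fraction=0.25):
    """Return (n_kept, n_held_out), rounding the held-out part up."""
    n_held_out = math.ceil(n_rows * test_fraction)
    return n_rows - n_held_out, n_held_out


N_ROWS = 581_012   # assumption: the full Covertype dataset
BATCH_SIZE = 1024  # the batch size passed to TrainerConfig in this walkthrough

n_train_val, n_test = split_sizes(N_ROWS)    # first split: train+val vs. test
n_train, n_val = split_sizes(n_train_val)    # second split: train vs. val
steps_per_epoch = math.ceil(n_train / BATCH_SIZE)

print(n_train, n_val, n_test, steps_per_epoch)
```

Under these assumptions one epoch takes 320 optimizer steps, which is consistent with the run summaries in this walkthrough: the CategoryEmbeddingModel run reports `trainer/global_step` 16640 at epoch 52, and 16640 / 52 = 320.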

Importing the Library

from pytorch_tabular import TabularModel
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    FTTransformerConfig,
    TabNetModelConfig,
    GANDALFConfig,
)
from pytorch_tabular.config import (
    DataConfig,
    OptimizerConfig,
    TrainerConfig,
    ExperimentConfig,
)
from pytorch_tabular.models.common.heads import LinearHeadConfig

Common Configs

data_config = DataConfig(
    target=[
        target_col
    ],  # target should always be a list. Multi-target is supported only for regression; multi-task classification is not implemented yet.
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)
trainer_config = TrainerConfig(
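    auto_lr_find=True,  # Runs the LR Finder to automatically derive a learning rate
    batch_size=1024,
    max_epochs=100,
    early_stopping="valid_loss",  # Monitor valid_loss for early stopping
    early_stopping_mode="min",  # Set the mode to min because a lower valid_loss is better
    early_stopping_patience=5,  # Number of epochs without improvement to wait before stopping
    checkpoints="valid_loss",  # Save the best checkpoint, monitoring valid_loss
    load_best=True,  # After training, load the best checkpoint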
)
optimizer_config = OptimizerConfig()

head_config = LinearHeadConfig(
    layers="",  # No additional layers in the head, only a mapping layer to output_dim
    dropout=0.1,
    initialization="kaiming",
).__dict__  # Convert to a dict to pass to the model config (OmegaConf doesn't accept objects)

EXP_PROJECT_NAME = "pytorch-tabular-covertype"
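
The early-stopping settings above (`early_stopping_mode="min"`, `early_stopping_patience=5`) amount to a simple rule: stop once the monitored loss has failed to improve for five validation checks in a row. Here is a minimal sketch of that rule in plain Python. It is an illustration only, not PyTorch Lightning's actual callback, and `stopping_epoch` and the loss values are made up for the example.

```python
def stopping_epoch(losses, patience=5):
    """Return the index of the epoch at which training would stop,
    or None if the patience budget is never exhausted."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(losses):
        if loss < best:  # mode="min": a lower loss counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value occurs at epoch 2.
valid_losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.36, 0.38, 0.39, 0.40]
print(stopping_epoch(valid_losses))
```

With `load_best=True`, the weights restored afterwards are those of the best checkpoint (epoch 2 in this toy sequence), not those from the epoch at which training stopped.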

Category Embedding Model

model_config = CategoryEmbeddingModelConfig(
    task="classification",
    layers="1024-512-512",  # Number of nodes in each layer
    activation="LeakyReLU",  # Activation between each layer
    learning_rate=1e-3,
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
)

experiment_config = ExperimentConfig(
    project_name=EXP_PROJECT_NAME,
    run_name="CategoryEmbeddingModel",
    exp_watch="gradients",
    log_target="wandb",
    log_logits=True,
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    experiment_config=experiment_config,
    verbose=False,
    suppress_lightning_logger=True,
)
tabular_model.fit(train=train, validation=val)
/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]
  warnings.warn(

Tracking run with wandb version 0.16.1
Run data is saved locally in ./wandb/run-20240106_114035-hvalkv16
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                      ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ CategoryEmbeddingBackbone │  823 K │
│ 1 │ _embedding_layer │ Embedding1dLayer          │    896 │
│ 2 │ head             │ LinearHead                │  3.6 K │
│ 3 │ loss             │ CrossEntropyLoss          │      0 │
└───┴──────────────────┴───────────────────────────┴────────┘
Trainable params: 827 K                                                                                            
Non-trainable params: 0                                                                                            
Total params: 827 K                                                                                                
Total estimated model params size (MB): 3                                                                          
Output()


<pytorch_lightning.trainer.trainer.Trainer at 0x7f14867a4850>
result = tabular_model.evaluate(test)
Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9159672856330872     │
│         test_loss         │    0.21389885246753693    │
└───────────────────────────┴───────────────────────────┘


# Although the run should finish automatically, it is safer to call wandb.finish() explicitly before starting a new experiment.
wandb.finish()
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

Run history:

epoch                ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
test_accuracy
test_loss
train_accuracy       ▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇███████████████
train_loss           █▇▆▅▅▅▄▄▃▄▃▄▄▃▂▃▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁
trainer/global_step  ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
valid_accuracy       ▁▂▃▃▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█▇██▇███████▇██▇
valid_loss           █▇▆▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▂▁▁▂

Run summary:

epoch                52
test_accuracy        0.91597
test_loss            0.2139
train_accuracy       0.91938
train_loss           0.20502
trainer/global_step  16640
valid_accuracy       0.89782
valid_loss           0.24614

View run CategoryEmbeddingModel_5 at: https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/hvalkv16
Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240106_114035-hvalkv16/logs
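
Each model section in this walkthrough builds an `ExperimentConfig` that differs only in `run_name`. One way to avoid retyping the shared settings is a small helper that returns the keyword arguments. This is a sketch, and `experiment_kwargs` is a hypothetical name, not part of PyTorch Tabular.

```python
EXP_PROJECT_NAME = "pytorch-tabular-covertype"


def experiment_kwargs(run_name, project_name=EXP_PROJECT_NAME):
    """Keyword arguments shared by every run in this walkthrough."""
    return {
        "project_name": project_name,
        "run_name": run_name,
        "exp_watch": "gradients",  # track gradients in W&B
        "log_target": "wandb",
        "log_logits": True,
    }


print(experiment_kwargs("FTTransformer"))

# Usage, with ExperimentConfig imported from pytorch_tabular.config:
# experiment_config = ExperimentConfig(**experiment_kwargs("FTTransformer"))
```

Keeping the shared values in one place also makes it harder for two runs to end up in different W&B projects by accident.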

FT Transformer

model_config = FTTransformerConfig(
    task="classification",
    num_attn_blocks=3,
    num_heads=4,
    learning_rate=1e-3,
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
)

experiment_config = ExperimentConfig(
    project_name=EXP_PROJECT_NAME,
    run_name="FTTransformer",
    exp_watch="gradients",
    log_target="wandb",
    log_logits=True,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    experiment_config=experiment_config,
    verbose=False,
    suppress_lightning_logger=True,
)
tabular_model.fit(train=train, validation=val)
/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]
  warnings.warn(

Tracking run with wandb version 0.16.1
Run data is saved locally in ./wandb/run-20240106_120910-k2qdphzr
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type                  ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ FTTransformerBackbone │ 86.5 K │
│ 1 │ _embedding_layer │ Embedding2dLayer      │  2.2 K │
│ 2 │ _head            │ LinearHead            │    231 │
│ 3 │ loss             │ CrossEntropyLoss      │      0 │
└───┴──────────────────┴───────────────────────┴────────┘
Trainable params: 89.0 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 89.0 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Output()


<pytorch_lightning.trainer.trainer.Trainer at 0x7f1486720910>
result = tabular_model.evaluate(test)
Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.9120706915855408     │
│         test_loss         │    0.21334321796894073    │
└───────────────────────────┴───────────────────────────┘


wandb.finish()
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

Run history:

epoch                ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_accuracy
test_loss
train_accuracy       ▁▃▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████
train_loss           █▇▆▅▆▄▄▄▄▄▃▄▃▃▂▃▃▃▂▂▂▁▃▃▃▄▃▂▂▄▂▂▁▂▁▃▂▃▂▁
trainer/global_step  ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
valid_accuracy       ▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████
valid_loss           █▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:

epoch                47
test_accuracy        0.91207
test_loss            0.21334
train_accuracy       0.88804
train_loss           0.23555
trainer/global_step  15040
valid_accuracy       0.91161
valid_loss           0.21692

View run FTTransformer_1 at: https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/k2qdphzr
Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240106_120910-k2qdphzr/logs

GANDALF

model_config = GANDALFConfig(
    task="classification",
    gflu_stages=10,
    learning_rate=1e-3,
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
)

experiment_config = ExperimentConfig(
    project_name=EXP_PROJECT_NAME,
    run_name="GANDALF",
    exp_watch="gradients",
    log_target="wandb",
    log_logits=True,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    experiment_config=experiment_config,
    verbose=False,
    suppress_lightning_logger=True,
)
tabular_model.fit(train=train, validation=val)
/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]
  warnings.warn(

Tracking run with wandb version 0.16.1
Run data is saved locally in ./wandb/run-20240106_123420-9kg2s1qg
Syncing run GANDALF_1 to Weights & Biases (docs)
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type             ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _backbone        │ GANDALFBackbone  │ 70.7 K │
│ 1 │ _embedding_layer │ Embedding1dLayer │    896 │
│ 2 │ _head            │ Sequential       │    252 │
│ 3 │ loss             │ CrossEntropyLoss │      0 │
└───┴──────────────────┴──────────────────┴────────┘
Trainable params: 71.9 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 71.9 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Output()


<pytorch_lightning.trainer.trainer.Trainer at 0x7f14866c8a50>
result = tabular_model.evaluate(test)
Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.8669493794441223     │
│         test_loss         │    0.32519233226776123    │
└───────────────────────────┴───────────────────────────┘


wandb.finish()
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

Run history:

epoch                ▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇█
test_accuracy
test_loss
train_accuracy       ▁▆▇███████
train_loss           █▇▄▃▃▂▁▃▂▂▂▂▃▁▂▃▁▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▂▃▂▁▂▁▂▄
trainer/global_step  ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
valid_accuracy       ▁▆▆▆▇█▇▇██
valid_loss           █▃▂▂▁▁▁▁▁▁

Run summary:

epoch                10
test_accuracy        0.86695
test_loss            0.32519
train_accuracy       0.86358
train_loss           0.42013
trainer/global_step  3200
valid_accuracy       0.86707
valid_loss           0.32778

View run GANDALF_1 at: https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/9kg2s1qg
Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240106_123420-9kg2s1qg/logs

TabNet Model

model_config = TabNetModelConfig(
    task="classification",
    learning_rate=1e-5,
    n_d=16,
    n_a=16,
    n_steps=4,
    head="LinearHead",  # Linear head
    head_config=head_config,  # Linear head config
)

experiment_config = ExperimentConfig(
    project_name=EXP_PROJECT_NAME,
    run_name="TabNet",
    exp_watch="gradients",
    log_target="wandb",
    log_logits=True,
)
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    experiment_config=experiment_config,
    verbose=False,
    suppress_lightning_logger=True,
)
tabular_model.fit(train=train, validation=val)
/home/manujosephv/pytorch_tabular/src/pytorch_tabular/models/base_model.py:164: UserWarning: Plotly is not installed. Please install plotly to log logits. You can install plotly using pip install plotly or install PyTorch Tabular using pip install pytorch-tabular[extra]
  warnings.warn(

Tracking run with wandb version 0.16.1
Run data is saved locally in ./wandb/run-20240106_124017-iw6q00dk
Syncing run TabNet_1 to Weights & Biases (docs)
wandb: logging graph, to disable use `wandb.watch(log_graph=False)`
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]
┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name             ┃ Type             ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ _embedding_layer │ Identity         │      0 │
│ 1 │ _backbone        │ TabNetBackbone   │ 29.3 K │
│ 2 │ _head            │ Identity         │      0 │
│ 3 │ loss             │ CrossEntropyLoss │      0 │
└───┴──────────────────┴──────────────────┴────────┘
Trainable params: 29.3 K                                                                                           
Non-trainable params: 0                                                                                            
Total params: 29.3 K                                                                                               
Total estimated model params size (MB): 0                                                                          
Output()


<pytorch_lightning.trainer.trainer.Trainer at 0x7f1487491310>
result = tabular_model.evaluate(test)
Output()
/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy       │    0.7186701893806458     │
│         test_loss         │    0.6771128177642822     │
└───────────────────────────┴───────────────────────────┘


wandb.finish()
wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job

Run history:

epoch                ▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇█
test_accuracy
test_loss
train_accuracy       ▁▄▅▇█▆
train_loss           ▇▄▃▂▃▃▃▃▁▄█▄▅▄▄▃▂▂▂▄▃▂▂▂▃▂▂▂▂▂▂▂▂▁▄▄▅▃
trainer/global_step  ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
valid_accuracy       █▇▁█▅▄
valid_loss           ▁▃█▁▂█

Run summary:

epoch                6
test_accuracy        0.71867
test_loss            0.67711
train_accuracy       0.70094
train_loss           0.7142
trainer/global_step  1920
valid_accuracy       0.65582
valid_loss           0.95996

View run TabNet_1 at: https://wandb.ai/manujosephv/pytorch-tabular-covertype/runs/iw6q00dk
Synced 6 W&B file(s), 1 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20240106_124017-iw6q00dk/logs

Accessing the Experiments

The runs can be accessed at https://wandb.ai/manujosephv/pytorch-tabular-covertype/

We can also inspect the gradient flow in each component of the model for debugging purposes.
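
For a quick side-by-side without opening the dashboard, the test metrics reported in the four run summaries can be ranked in a few lines of plain Python. The numbers are copied from those summaries; `results` and `ranked` are just illustrative names.

```python
# (test_accuracy, test_loss) as reported in each run summary
results = {
    "CategoryEmbeddingModel": (0.91597, 0.2139),
    "FTTransformer": (0.91207, 0.21334),
    "GANDALF": (0.86695, 0.32519),
    "TabNet": (0.71867, 0.67711),
}

# Rank the runs by test accuracy, highest first
ranked = sorted(results.items(), key=lambda item: item[1][0], reverse=True)

for name, (accuracy, loss) in ranked:
    print(f"{name:<24} accuracy={accuracy:.5f} loss={loss:.5f}")
```

Each figure comes from a single run, so small gaps such as the one between the first two models should not be over-read.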