Skip to content

训练

训练深度学习模型可能会变得非常复杂。PyTorch Tabular 通过继承 PyTorch Lightning,将整个工作负载转移到底层的 PyTorch Lightning 框架上。它的设计旨在使模型训练变得轻而易举,同时赋予你灵活性,让你能够定制训练过程。

PyTorch Tabular 中的训练器直接或间接地继承了 PyTorch Lightning 训练器的所有功能。

基本用法

你最常设置的参数包括:

  • batch_size: int: 每个训练批次中的样本数量。默认为 64
  • max_epochs: int: 要运行的最大周期数。在启用早停的情况下,这是最大值;在没有早停的情况下,这是将要运行的周期数。默认为 10
  • devices: (Optional[int]): 用于训练的设备数量(整数)。-1 表示使用所有可用设备。默认使用所有可用设备(-1)

  • accelerator: Optional[str]: 用于训练的加速器。可以是 'cpu'、'gpu'、'tpu'、'ipu'、'auto' 之一。默认为 'auto'。

  • load_best: int: 标志,用于在训练期间加载保存的最佳模型。如果关闭了检查点保存,则此项将被忽略。默认为 True

使用示例

trainer_config = TrainerConfig(batch_size=64, max_epochs=10, accelerator="auto")

PyTorch Tabular 默认使用早停机制,并监控 valid_loss 以停止训练。检查点保存也默认开启,它会监控 valid_loss 并将最佳模型保存在 saved_models 文件夹中。所有这些都可以在下一节中进行配置。

高级用法

早停和检查点保存

早停默认开启。但你可以通过将 early_stopping 设置为 None 来关闭它。如果你想监控其他指标,只需在 early_stopping 参数中提供该指标名称。控制早停的其他几个参数包括:

  • early_stopping_min_delta: float: 损失/指标中被视为改进的最小增量。默认为 0.001
  • early_stopping_mode: str: 损失/指标应优化的方向。选项为 maxmin。默认为 min
  • early_stopping_patience: int: 在没有进一步改进损失/指标的情况下等待的周期数。默认为 3
  • min_epochs: int: 要运行的最小周期数。无论停止标准如何,都会运行这么多周期。默认为 1

检查点保存也默认开启,要关闭它,可以将 checkpoints 参数设置为 None。如果你想监控其他指标,只需在 early_stopping 参数中提供该指标名称。控制检查点保存的其他几个参数包括:

  • checkpoints_path: str: 保存模型的路径。默认为 saved_models
  • checkpoints_mode: str: 损失/指标应优化的方向。选项为 maxmin。默认为 min
  • checkpoints_save_top_k: int: 要保存的最佳模型数量。如果你想保存多个最佳模型,可以将此参数设置为 >1。默认为 1

Note

确保你要跟踪的指标/损失名称与日志中的名称完全匹配。推荐的方法是运行一个模型并评估结果。从结果字典中,你可以选择一个键来在训练期间跟踪。

学习率查找器

首先在这篇论文 Cyclical Learning Rates for Training Neural Networks 中提出,随后被 fast.ai 推广,这是一种无需昂贵搜索即可达到最优学习率附近的技术。PyTorch Tabular 允许你使用论文中提出的方法找到最佳学习率,并自动将其用于训练网络。所有这些都可以通过一个简单的标志 auto_lr_find 开启。

我们还可以使用 [pytorch_tabular.TabularModel.find_learning_rate] 作为一个单独的步骤来运行学习率查找器。

控制梯度/优化

在训练过程中,有时你可能需要对梯度优化过程进行更严格的控制。例如,如果梯度爆炸,你可能希望在每次更新前裁剪梯度值。gradient_clip_val 允许你这样做。

有时,你可能希望在执行反向传播之前跨多个批次累积梯度(可能是因为较大的批次大小不适合你的 GPU)。PyTorch Tabular 允许你通过 accumulate_grad_batches 来实现这一点。

调试

很多时候,你需要调试模型,看看为什么它没有按预期表现。甚至在开发新模型时,你也需要大量调试模型。PyTorch Lightning 为此提供了一些功能,PyTorch Tabular 也采用了这些功能。 为了找出性能瓶颈,我们可以使用:

  • profiler: Optional[str]: 在训练过程中分析各个步骤,以帮助识别瓶颈。可选值为:None simple advanced。默认为 None

为了检查整个设置是否无误运行,我们可以使用:

  • fast_dev_run: Optional[str]: 快速调试验证运行。默认为 False

如果模型学习不正常:

  • overfit_batches: float: 使用训练集的这部分数据。如果不为零,将使用相同的训练集进行验证和测试。如果训练数据加载器设置了 shuffle=True,Lightning 会自动禁用它。适用于快速调试或有意过拟合。默认为 0

  • track_grad_norm: bool: 仅在设置实验跟踪时使用。在日志记录器中跟踪和记录梯度范数。默认值为 -1 表示不跟踪。1 表示 L1 范数,2 表示 L2 范数,依此类推。默认为 False。如果梯度范数迅速降至零,则存在问题。

使用完整的 PyTorch Lightning Trainer

要充分发挥 PyTorch Lightning Trainer 的潜力,可以使用 trainer_kwargs 参数。这将允许你传递 PyTorch Lightning Trainer 支持的任何参数。完整的文档可以在这里找到

pytorch_tabular.config.TrainerConfig dataclass

训练器配置.

Parameters:

Name Type Description Default
batch_size int

每个训练批次中的样本数量

64
data_aware_init_batch_size int

数据感知初始化时每个训练批次中的样本数量, 适用时默认值为2000

2000
fast_dev_run bool

如果设置为n(整数)则运行n个批次,如果设置为True则运行1个批次, 用于查找训练、验证和测试中的任何错误(即一种单元测试).

False
max_epochs int

要运行的最大周期数

10
min_epochs Optional[int]

强制训练至少这些周期数.默认值为1

1
max_time Optional[int]

经过此时间后停止训练.默认禁用(None)

None
accelerator Optional[str]

用于训练的加速器.可以是以下之一: 'cpu','gpu','tpu','ipu', 'mps', 'auto'.默认为'auto'. 可选值为:[cpu,gpu,tpu,ipu,'mps',auto].

'auto'
devices Optional[int]

用于训练的设备数量(整数).-1表示使用所有可用设备. 默认情况下,使用所有可用设备(-1)

-1
devices_list Optional[List[int]]

用于训练的设备列表(列表).如果指定, 优先于devices参数.默认为None

None
accumulate_grad_batches int

每k个批次或按字典设置累积梯度.训练器 也会在最后一个不可整除的步骤数上调用optimizer.step().

1
auto_lr_find bool

在调用trainer.tune()时运行学习率查找算法, 以找到最佳初始学习率.

False
auto_select_gpus bool

如果启用且devices为整数,则自动选择可用GPU. 这在GPU配置为'独占模式'时特别有用,即一次只有一个进程可以访问它们.

True
check_val_every_n_epoch int

每n个训练周期检查一次验证.

1
gradient_clip_val float

梯度裁剪值

0.0
overfit_batches float

使用训练集的此部分数据.如果不为零,将使用相同的 训练集进行验证和测试.如果训练数据加载器的shuffle=True,Lightning 将自动禁用它.对于快速调试或故意过拟合很有用.

0.0
deterministic bool

如果为真,启用cudnn.deterministic.可能会使系统变慢,但 确保可重复性.

False
profiler Optional[str]

在训练期间分析各个步骤并协助识别 瓶颈.可以是None、simple或advanced、pytorch.可选值为: [None,simple,advanced,pytorch].

None
early_stopping Optional[str]

需要监控的损失/指标以进行早停.如果 为None,则不会进行早停

'valid_loss'
early_stopping_min_delta float

早停中损失/指标的最小变化量, 符合改进条件

0.001
early_stopping_mode str

损失/指标应优化的方向.可选值为: [max,min].

'min'
early_stopping_patience int

在损失/指标没有进一步改善之前等待的周期数

3
early_stopping_kwargs Optional[Dict]

早停回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning EarlyStopping回调的文档.

lambda: {}()
checkpoints Optional[str]

需要监控的损失/指标以进行检查点保存.如果为None, 则不会进行检查点保存

'valid_loss'
checkpoints_path str

保存模型的路径

'saved_models'
checkpoints_every_n_epochs int

检查点之间的训练步数

1
checkpoints_name Optional[str]

保存模型的名称.如果留空, 首先会查找experiment_config中的run_name,如果也为None,则使用 类似task_version的通用名称.

None
checkpoints_mode str

损失/指标应优化的方向

'min'
checkpoints_save_top_k int

保存的最佳模型数量

1
checkpoints_kwargs Optional[Dict]

检查点回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning ModelCheckpoint回调的文档.

lambda: {}()
load_best bool

标志以加载训练期间保存的最佳模型

True
track_grad_norm int

在日志记录器中跟踪和记录梯度范数.默认值-1表示不跟踪. 1表示L1范数,2表示L2范数,依此类推.

-1
progress_bar str

进度条类型.可以是以下之一:none, simple, rich.默认为rich.

'rich'
precision int

模型的精度.可以是以下之一:32, 16, 64.默认为32. 可选值为:[32,16,64].

32
seed int

随机数生成器的种子.默认为42

42
trainer_kwargs Dict[str, Any]

传递给PyTorch Lightning Trainer的额外关键字参数.请参阅 https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer

dict()
Source code in src/pytorch_tabular/config/config.py
@dataclass
class TrainerConfig:
    """训练器配置.

    Parameters:
        batch_size (int): 每个训练批次中的样本数量

        data_aware_init_batch_size (int): 数据感知初始化时每个训练批次中的样本数量,
            适用时默认值为2000

        fast_dev_run (bool): 如果设置为``n``(整数)则运行n个批次,如果设置为``True``则运行1个批次,
                用于查找训练、验证和测试中的任何错误(即一种单元测试).

        max_epochs (int): 要运行的最大周期数

        min_epochs (Optional[int]): 强制训练至少这些周期数.默认值为1

        max_time (Optional[int]): 经过此时间后停止训练.默认禁用(None)

        accelerator (Optional[str]): 用于训练的加速器.可以是以下之一:
                'cpu','gpu','tpu','ipu', 'mps', 'auto'.默认为'auto'.
                可选值为:[`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`].

        devices (Optional[int]): 用于训练的设备数量(整数).-1表示使用所有可用设备.
                默认情况下,使用所有可用设备(-1)

        devices_list (Optional[List[int]]): 用于训练的设备列表(列表).如果指定,
                优先于`devices`参数.默认为None

        accumulate_grad_batches (int): 每k个批次或按字典设置累积梯度.训练器
                也会在最后一个不可整除的步骤数上调用optimizer.step().

        auto_lr_find (bool): 在调用trainer.tune()时运行学习率查找算法,
                以找到最佳初始学习率.

        auto_select_gpus (bool): 如果启用且`devices`为整数,则自动选择可用GPU.
                这在GPU配置为'独占模式'时特别有用,即一次只有一个进程可以访问它们.

        check_val_every_n_epoch (int): 每n个训练周期检查一次验证.

        gradient_clip_val (float): 梯度裁剪值

        overfit_batches (float): 使用训练集的此部分数据.如果不为零,将使用相同的
                训练集进行验证和测试.如果训练数据加载器的shuffle=True,Lightning
                将自动禁用它.对于快速调试或故意过拟合很有用.

        deterministic (bool): 如果为真,启用cudnn.deterministic.可能会使系统变慢,但
                确保可重复性.

        profiler (Optional[str]): 在训练期间分析各个步骤并协助识别
                瓶颈.可以是None、simple或advanced、pytorch.可选值为:
                [`None`,`simple`,`advanced`,`pytorch`].

        early_stopping (Optional[str]): 需要监控的损失/指标以进行早停.如果
                为None,则不会进行早停

        early_stopping_min_delta (float): 早停中损失/指标的最小变化量,
                符合改进条件

        early_stopping_mode (str): 损失/指标应优化的方向.可选值为:
                [`max`,`min`].

        early_stopping_patience (int): 在损失/指标没有进一步改善之前等待的周期数

        early_stopping_kwargs (Optional[Dict]): 早停回调的额外关键字参数.
                有关更多详细信息,请参阅PyTorch Lightning EarlyStopping回调的文档.

        checkpoints (Optional[str]): 需要监控的损失/指标以进行检查点保存.如果为None,
                则不会进行检查点保存

        checkpoints_path (str): 保存模型的路径

        checkpoints_every_n_epochs (int): 检查点之间的训练步数

        checkpoints_name (Optional[str]): 保存模型的名称.如果留空,
                首先会查找experiment_config中的`run_name`,如果也为None,则使用
                类似task_version的通用名称.

        checkpoints_mode (str): 损失/指标应优化的方向

        checkpoints_save_top_k (int): 保存的最佳模型数量

        checkpoints_kwargs (Optional[Dict]): 检查点回调的额外关键字参数.
                有关更多详细信息,请参阅PyTorch Lightning ModelCheckpoint回调的文档.

        load_best (bool): 标志以加载训练期间保存的最佳模型

        track_grad_norm (int): 在日志记录器中跟踪和记录梯度范数.默认值-1表示不跟踪.
                1表示L1范数,2表示L2范数,依此类推.

        progress_bar (str): 进度条类型.可以是以下之一:`none`, `simple`, `rich`.默认为`rich`.

        precision (int): 模型的精度.可以是以下之一:`32`, `16`, `64`.默认为`32`.
                可选值为:[`32`,`16`,`64`].

        seed (int): 随机数生成器的种子.默认为42

        trainer_kwargs (Dict[str, Any]): 传递给PyTorch Lightning Trainer的额外关键字参数.请参阅
                https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer"""

    batch_size: int = field(default=64, metadata={"help": "Number of samples in each batch of training"})
    data_aware_init_batch_size: int = field(
        default=2000,
        metadata={
            "help": "Number of samples in each batch of training for the data-aware initialization,"
            " when applicable. Defaults to 2000"
        },
    )
    fast_dev_run: bool = field(
        default=False,
        metadata={
            "help": "runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train,"
            " val and test to find any bugs (ie: a sort of unit test)."
        },
    )
    max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
    min_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "Force training for at least these many epochs. 1 by default"},
    )
    max_time: Optional[int] = field(
        default=None,
        metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
    )
    accelerator: Optional[str] = field(
        default="auto",
        metadata={
            "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'."
            " Defaults to 'auto'",
            "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
        },
    )
    devices: Optional[int] = field(
        default=-1,
        metadata={
            "help": "Number of devices to train on. -1 uses all available devices."
            " By default uses all available devices (-1)",
        },
    )
    devices_list: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
            " Defaults to None",
        },
    )

    accumulate_grad_batches: int = field(
        default=1,
        metadata={
            "help": "Accumulates grads every k batches or as set up in the dict."
            " Trainer also calls optimizer.step() for the last indivisible step number."
        },
    )
    auto_lr_find: bool = field(
        default=False,
        metadata={
            "help": "Runs a learning rate finder algorithm (see this paper) when calling trainer.tune(),"
            " to find optimal initial learning rate."
        },
    )
    auto_select_gpus: bool = field(
        default=True,
        metadata={
            "help": "If enabled and `devices` is an integer, pick available gpus automatically."
            " This is especially useful when GPUs are configured to be in 'exclusive mode',"
            " such that only one process at a time can access them."
        },
    )
    check_val_every_n_epoch: int = field(default=1, metadata={"help": "Check val every n train epochs."})
    gradient_clip_val: float = field(default=0.0, metadata={"help": "Gradient clipping value"})
    overfit_batches: float = field(
        default=0.0,
        metadata={
            "help": "Uses this much data of the training set. If nonzero, will use the same training set"
            " for validation and testing. If the training dataloaders have shuffle=True,"
            " Lightning will automatically disable it."
            " Useful for quickly debugging or trying to overfit on purpose."
        },
    )
    deterministic: bool = field(
        default=False,
        metadata={
            "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
        },
    )
    profiler: Optional[str] = field(
        default=None,
        metadata={
            "help": "To profile individual steps during training and assist in identifying bottlenecks."
            " None, simple or advanced, pytorch",
            "choices": [None, "simple", "advanced", "pytorch"],
        },
    )
    early_stopping: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric that needed to be monitored for early stopping."
            " If None, there will be no early stopping"
        },
    )
    early_stopping_min_delta: float = field(
        default=0.001,
        metadata={"help": "The minimum delta in the loss/metric which qualifies as an improvement in early stopping"},
    )
    early_stopping_mode: str = field(
        default="min",
        metadata={
            "help": "The direction in which the loss/metric should be optimized",
            "choices": ["max", "min"],
        },
    )
    early_stopping_patience: int = field(
        default=3,
        metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
    )
    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the early stopping callback."
            " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
        },
    )
    checkpoints: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints"
        },
    )
    checkpoints_path: str = field(
        default="saved_models",
        metadata={"help": "The path where the saved models will be"},
    )
    checkpoints_every_n_epochs: int = field(
        default=1,
        metadata={"help": "Number of training steps between checkpoints"},
    )
    checkpoints_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name under which the models will be saved. If left blank,"
            " first it will look for `run_name` in experiment_config and if that is also None"
            " then it will use a generic name like task_version."
        },
    )
    checkpoints_mode: str = field(
        default="min",
        metadata={"help": "The direction in which the loss/metric should be optimized"},
    )
    checkpoints_save_top_k: int = field(
        default=1,
        metadata={"help": "The number of best models to save"},
    )
    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
            " for the PyTorch Lightning ModelCheckpoint callback for more details."
        },
    )
    load_best: bool = field(
        default=True,
        metadata={"help": "Flag to load the best model saved during training"},
    )
    track_grad_norm: int = field(
        default=-1,
        metadata={
            "help": "Track and Log Gradient Norms in the logger. -1 by default means no tracking. "
            "1 for the L1 norm, 2 for L2 norm, etc."
        },
    )
    progress_bar: str = field(
        default="rich",
        metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
    )
    precision: int = field(
        default=32,
        metadata={
            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
            "choices": [32, 16, 64],
        },
    )
    seed: int = field(
        default=42,
        metadata={"help": "Seed for random number generators. Defaults to 42"},
    )
    trainer_kwargs: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
    )

    def __post_init__(self):
        _validate_choices(self)
        if self.accelerator is None:
            self.accelerator = "cpu"
        if self.devices_list is not None:
            self.devices = self.devices_list
        delattr(self, "devices_list")
        for key in self.early_stopping_kwargs.keys():
            if key in ["min_delta", "mode", "patience"]:
                raise ValueError(
                    f"Cannot override {key} in early_stopping_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )
        for key in self.checkpoints_kwargs.keys():
            if key in ["dirpath", "filename", "monitor", "save_top_k", "mode", "every_n_epochs"]:
                raise ValueError(
                    f"Cannot override {key} in checkpoints_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )