Skip to content

优化器与学习率调度器

优化器是梯度下降过程的核心,也是我们需要训练出一个好模型的关键组件。Pytorch Tabular 默认使用 Adam 优化器,学习率为 1e-3。这主要是因为经验法则提供了一个良好的起点。

有时,学习率调度器让你在优化过程中对学习率的使用有更精细的控制。默认情况下,PyTorch Tabular 不应用任何学习率调度器。

基本用法

  • optimizer: str: 来自 torch.optim 的标准优化器之一。默认为 Adam
  • optimizer_params: Dict: 优化器的参数。如果留空,将使用默认参数。
  • lr_scheduler: str: 要使用的学习率调度器的名称,如果有的话,来自 torch.optim.lr_scheduler。如果为 None,则不使用任何调度器。默认为 None
  • lr_scheduler_params: Dict: 学习率调度器的参数。如果留空,将使用默认参数。
  • lr_scheduler_monitor_metric: str: 与 ReduceLROnPlateau 一起使用,其中平台是根据此指标决定的。默认为 val_loss

使用示例

optimizer_config = OptimizerConfig(
    optimizer="RMSprop", lr_scheduler="StepLR", lr_scheduler_params={"step_size": 10}
)

高级用法

虽然 Config 对象限制你只能使用 torch.optim 中的标准优化器和学习率调度器,但你可以使用任何自定义的优化器或学习率调度器,只要它们是标准优化器和调度器的直接替代品。你可以通过 TabularModelfit 方法来实现这一点,该方法允许你覆盖通过配置设置的优化器和学习率。

使用示例

from torch_optimizer import QHAdam

tabular_model.fit(
    train=train,
    validation=val,
    optimizer=QHAdam,
    optimizer_params={"nus": (0.7, 1.0), "betas": (0.95, 0.998)},
)

pytorch_tabular.config.OptimizerConfig dataclass

优化器和学习率调度器配置.

Parameters:

Name Type Description Default
optimizer str

来自 torch.optim 的标准优化器之一, 或提供完整的 Python 路径,例如 "torch_optimizer.RAdam".

'Adam'
optimizer_params Dict

优化器的参数.如果留空,将使用默认参数.

lambda: {}()
lr_scheduler Optional[str]

要使用的学习率调度器的名称,如果有的话,来自 torch.optim.lr_scheduler.如果为 None,将不使用任何调度器.默认为 None.

None
lr_scheduler_params Optional[Dict]

学习率调度器的参数.如果留空,将使用默认参数.

lambda: {}()
lr_scheduler_monitor_metric Optional[str]

与 ReduceLROnPlateau 一起使用,其中平台是基于此指标决定的.

'valid_loss'
Source code in src/pytorch_tabular/config/config.py
@dataclass
class OptimizerConfig:
    """优化器和学习率调度器配置.

    Parameters:
        optimizer (str): 来自 [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) 的标准优化器之一,
                或提供完整的 Python 路径,例如 "torch_optimizer.RAdam".

        optimizer_params (Dict): 优化器的参数.如果留空,将使用默认参数.

        lr_scheduler (Optional[str]): 要使用的学习率调度器的名称,如果有的话,来自
                [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-
                rate).如果为 None,将不使用任何调度器.默认为 `None`.

        lr_scheduler_params (Optional[Dict]): 学习率调度器的参数.如果留空,将使用默认参数.

        lr_scheduler_monitor_metric (Optional[str]): 与 ReduceLROnPlateau 一起使用,其中平台是基于此指标决定的."""

    optimizer: str = field(
        default="Adam",
        metadata={
            "help": "Any of the standard optimizers from"
            " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,"
            " for example 'torch_optimizer.RAdam'."
        },
    )
    optimizer_params: Dict = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
    )
    lr_scheduler: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the LearningRateScheduler to use, if any, from"
            " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
            " If None, will not use any scheduler. Defaults to `None`",
        },
    )
    lr_scheduler_params: Optional[Dict] = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
    )

    lr_scheduler_monitor_metric: Optional[str] = field(
        default="valid_loss",
        metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
    )

    @staticmethod
    def read_from_yaml(filename: str = "config/optimizer_config.yml"):
        config = _read_yaml(filename)
        if config["lr_scheduler_params"] is None:
            config["lr_scheduler_params"] = {}
        return OptimizerConfig(**config)