Skip to content

配置

核心配置

数据配置.

Parameters:

Name Type Description Default
target Optional[List[str]]

目标列名称的字符串列表.对于SSL任务以外的所有任务都是必需的.

None
continuous_cols List

数值字段的列名称.默认为 []

list()
categorical_cols List

需要特殊处理的分类字段的列名称.默认为 []

list()
date_columns List

(列名, 频率, 格式) 元组形式的日期字段.例如,名为 introduction_date 且频率为 "2023-12" 的月度字段应有一个条目 ('intro_date','M','%Y-%m')

list()
encode_date_columns bool

是否对从日期派生的变量进行编码

True
validation_split Optional[float]

保留为验证集的训练行百分比.仅在未单独提供验证数据时使用

0.2
continuous_feature_transform Optional[str]

是否在建模前对特征进行变换.默认关闭.可选值为: [None,yeo-johnson,box-cox, quantile_normal,quantile_uniform].

None
normalize_continuous_features bool

是否对输入特征(连续)进行归一化

True
quantile_noise int

未实现.如果指定,将在数据上拟合 QuantileTransformer,并添加标准差为 :quantile_noise: * data.std 的高斯噪声;这将使离散值更易分离.请注意,此变换不会对结果数据应用高斯噪声,噪声仅用于 QuantileTransformer

0
num_workers Optional[int]

用于数据加载的工作线程数.对于 Windows 系统始终设置为 0

0
pin_memory bool

是否为数据加载固定内存

True
handle_unknown_categories bool

是否将分类列中的未知或新值处理为未知

True
handle_missing_values bool

是否将分类列中的缺失值处理为未知

True
Source code in src/pytorch_tabular/config/config.py
@dataclass
class DataConfig:
    """数据配置.

    Parameters:
        target (Optional[List[str]]): 目标列名称的字符串列表.对于SSL任务以外的所有任务都是必需的.

        continuous_cols (List): 数值字段的列名称.默认为 []

        categorical_cols (List): 需要特殊处理的分类字段的列名称.默认为 []

        date_columns (List): (列名, 频率, 格式) 元组形式的日期字段.例如,名为 introduction_date 且频率为 "2023-12" 的月度字段应有一个条目 ('intro_date','M','%Y-%m')

        encode_date_columns (bool): 是否对从日期派生的变量进行编码

        validation_split (Optional[float]): 保留为验证集的训练行百分比.仅在未单独提供验证数据时使用

        continuous_feature_transform (Optional[str]): 是否在建模前对特征进行变换.默认关闭.可选值为: [`None`,`yeo-johnson`,`box-cox`, `quantile_normal`,`quantile_uniform`].

        normalize_continuous_features (bool): 是否对输入特征(连续)进行归一化

        quantile_noise (int): 未实现.如果指定,将在数据上拟合 QuantileTransformer,并添加标准差为 :quantile_noise: * data.std 的高斯噪声;这将使离散值更易分离.请注意,此变换不会对结果数据应用高斯噪声,噪声仅用于 QuantileTransformer

        num_workers (Optional[int]): 用于数据加载的工作线程数.对于 Windows 系统始终设置为 0

        pin_memory (bool): 是否为数据加载固定内存

        handle_unknown_categories (bool): 是否将分类列中的未知或新值处理为未知

        handle_missing_values (bool): 是否将分类列中的缺失值处理为未知"""

    target: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "A list of strings with the names of the target column(s)."
            " It is mandatory for all except SSL tasks."
        },
    )
    continuous_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the numeric fields. Defaults to []"},
    )
    categorical_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
    )
    date_columns: List = field(
        default_factory=list,
        metadata={
            "help": "(Column names, Freq) tuples of the date fields. For eg. a field named"
            " introduction_date and with a monthly frequency like '2023-12' should have"
            " an entry ('intro_date','M','%Y-%m')"
        },
    )

    encode_date_columns: bool = field(
        default=True,
        metadata={"help": "Whether or not to encode the derived variables from date"},
    )
    validation_split: Optional[float] = field(
        default=0.2,
        metadata={
            "help": "Percentage of Training rows to keep aside as validation."
            " Used only if Validation Data is not given separately"
        },
    )
    continuous_feature_transform: Optional[str] = field(
        default=None,
        metadata={
            "help": "Whether or not to transform the features before modelling. By default it is turned off.",
            "choices": [
                None,
                "yeo-johnson",
                "box-cox",
                "quantile_normal",
                "quantile_uniform",
            ],
        },
    )
    normalize_continuous_features: bool = field(
        default=True,
        metadata={"help": "Flag to normalize the input features (continuous)"},
    )
    quantile_noise: int = field(
        default=0,
        metadata={
            "help": "NOT IMPLEMENTED. If specified fits QuantileTransformer on data with added gaussian noise"
            " with std = :quantile_noise: * data.std ; this will cause discrete values to be more separable."
            " Please not that this transformation does NOT apply gaussian noise to the resulting data,"
            " the noise is only applied for QuantileTransformer"
        },
    )
    num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of workers used for data loading. For windows always set to 0"},
    )
    pin_memory: bool = field(
        default=True,
        metadata={"help": "Whether or not to pin memory for data loading."},
    )
    handle_unknown_categories: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle unknown or new values in categorical columns as unknown"},
    )
    handle_missing_values: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
    )

    def __post_init__(self):
        assert (
            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
        _validate_choices(self)
        if os.name == "nt" and self.num_workers != 0:
            print("Windows does not support num_workers > 0. Setting num_workers to 0")
            self.num_workers = 0

基础模型配置.

Parameters:

Name Type Description Default
task str

指定问题是回归还是分类.backbone 是一种将模型视为生成特征的主干的任务.主要用于内部SSL及相关任务.可选值为:[regression,classification,backbone].

required
head Optional[str]

模型使用的头部.应为 pytorch_tabular.models.common.heads 中定义的头部之一.默认为 LinearHead.可选值为:[None,LinearHead,MixtureDensityHead].

'LinearHead'
head_config Optional[Dict]

定义头部的配置字典.如果留空,将初始化为默认的线性头部.

lambda: {'layers': ''}()
embedding_dims Optional[List]

每个分类列的嵌入维度列表,格式为 (基数, 嵌入维度).如果留空,将根据分类列的基数推断,规则为 min(50, (x + 1) // 2).

None
embedding_dropout float

应用于分类嵌入的丢弃率.默认为 0.0.

0.0
batch_norm_continuous_input bool

如果为 True,将通过 BatchNorm 层对连续层进行归一化.

True
virtual_batch_size Optional[int]

如果不为 None,所有 BatchNorm 将被转换为 GhostBatchNorm,并指定虚拟批量大小.默认为 None.

None
learning_rate float

模型的学习率.默认为 1e-3.

0.001
loss Optional[str]

应用的损失函数.默认情况下,回归为 MSELoss,分类为 CrossEntropyLoss.除非你确定自己在做什么,否则请保持为 MSELoss 或 L1Loss 用于回归,CrossEntropyLoss 用于分类.

None
metrics Optional[List[str]]

训练期间需要跟踪的指标列表.指标应为 torchmetrics 中实现的功能性指标之一.默认情况下,分类为 accuracy,回归为 mean_squared_error.

None
metrics_prob_input Optional[bool]

配置中定义的分类指标的强制参数.定义指标函数的输入是概率还是类别.长度应与指标数量相同.默认为 None.

None
metrics_params Optional[List]

传递给指标函数的参数.task 强制为 multiclass,因为多分类版本可以处理二分类,并且为了简化,我们仅使用 multiclass.

None
target_range Optional[List]

限制输出变量的范围.当前忽略多目标回归.通常用于回归问题.如果留空,将不应用任何限制.

None
seed int

用于可重复性的种子.默认为 42.

42
Source code in src/pytorch_tabular/config/config.py
@dataclass
class ModelConfig:
    """基础模型配置.

    Parameters:
        task (str): 指定问题是回归还是分类.`backbone` 是一种将模型视为生成特征的主干的任务.主要用于内部SSL及相关任务.可选值为:[`regression`,`classification`,`backbone`].

        head (Optional[str]): 模型使用的头部.应为 `pytorch_tabular.models.common.heads` 中定义的头部之一.默认为 LinearHead.可选值为:[`None`,`LinearHead`,`MixtureDensityHead`].

        head_config (Optional[Dict]): 定义头部的配置字典.如果留空,将初始化为默认的线性头部.

        embedding_dims (Optional[List]): 每个分类列的嵌入维度列表,格式为 (基数, 嵌入维度).如果留空,将根据分类列的基数推断,规则为 min(50, (x + 1) // 2).

        embedding_dropout (float): 应用于分类嵌入的丢弃率.默认为 0.0.

        batch_norm_continuous_input (bool): 如果为 True,将通过 BatchNorm 层对连续层进行归一化.

        virtual_batch_size (Optional[int]): 如果不为 None,所有 BatchNorm 将被转换为 GhostBatchNorm,并指定虚拟批量大小.默认为 None.

        learning_rate (float): 模型的学习率.默认为 1e-3.

        loss (Optional[str]): 应用的损失函数.默认情况下,回归为 MSELoss,分类为 CrossEntropyLoss.除非你确定自己在做什么,否则请保持为 MSELoss 或 L1Loss 用于回归,CrossEntropyLoss 用于分类.

        metrics (Optional[List[str]]): 训练期间需要跟踪的指标列表.指标应为 ``torchmetrics`` 中实现的功能性指标之一.默认情况下,分类为 accuracy,回归为 mean_squared_error.

        metrics_prob_input (Optional[bool]): 配置中定义的分类指标的强制参数.定义指标函数的输入是概率还是类别.长度应与指标数量相同.默认为 None.

        metrics_params (Optional[List]): 传递给指标函数的参数.`task` 强制为 `multiclass`,因为多分类版本可以处理二分类,并且为了简化,我们仅使用 `multiclass`.

        target_range (Optional[List]): 限制输出变量的范围.当前忽略多目标回归.通常用于回归问题.如果留空,将不应用任何限制.

        seed (int): 用于可重复性的种子.默认为 42."""

    task: str = field(
        metadata={
            "help": "Specify whether the problem is regression or classification."
            " `backbone` is a task which considers the model as a backbone to generate features."
            " Mostly used internally for SSL and related tasks.",
            "choices": ["regression", "classification", "backbone"],
        }
    )

    head: Optional[str] = field(
        default="LinearHead",
        metadata={
            "help": "The head to be used for the model. Should be one of the heads defined"
            " in `pytorch_tabular.models.common.heads`. Defaults to  LinearHead",
            "choices": [None, "LinearHead", "MixtureDensityHead"],
        },
    )

    head_config: Optional[Dict] = field(
        default_factory=lambda: {"layers": ""},
        metadata={
            "help": "The config as a dict which defines the head."
            " If left empty, will be initialized as default linear head."
        },
    )
    embedding_dims: Optional[List] = field(
        default=None,
        metadata={
            "help": "The dimensions of the embedding for each categorical column as a list of tuples "
            "(cardinality, embedding_dim). If left empty, will infer using the cardinality of the "
            "categorical column using the rule min(50, (x + 1) // 2)"
        },
    )
    embedding_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.0"},
    )
    batch_norm_continuous_input: bool = field(
        default=True,
        metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
    )

    learning_rate: float = field(
        default=1e-3,
        metadata={"help": "The learning rate of the model. Defaults to 1e-3."},
    )
    loss: Optional[str] = field(
        default=None,
        metadata={
            "help": "The loss function to be applied. By Default it is MSELoss for regression "
            "and CrossEntropyLoss for classification. Unless you are sure what you are doing, "
            "leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification"
        },
    )
    metrics: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "the list of metrics you need to track during training. The metrics should be one "
            "of the functional metrics implemented in ``torchmetrics``. To use your own metric, please "
            "use the `metric` param in the `fit` method By default, it is accuracy if classification "
            "and mean_squared_error for regression"
        },
    )
    metrics_prob_input: Optional[List[bool]] = field(
        default=None,
        metadata={
            "help": "Is a mandatory parameter for classification metrics defined in the config. This defines "
            "whether the input to the metric function is the probability or the class. Length should be same "
            "as the number of metrics. Defaults to None."
        },
    )
    metrics_params: Optional[List] = field(
        default=None,
        metadata={
            "help": "The parameters to be passed to the metrics function. `task` is forced to be `multiclass`` "
            "because the multiclass version can handle binary as well and for simplicity we are only using "
            "`multiclass`."
        },
    )
    target_range: Optional[List] = field(
        default=None,
        metadata={
            "help": "The range in which we should limit the output variable. "
            "Currently ignored for multi-target regression. Typically used for Regression problems. "
            "If left empty, will not apply any restrictions"
        },
    )

    virtual_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
            " with this virtual batch size. Defaults to None"
        },
    )

    seed: int = field(
        default=42,
        metadata={"help": "The seed for reproducibility. Defaults to 42"},
    )

    _module_src: str = field(default="models")
    _model_name: str = field(default="Model")
    _backbone_name: str = field(default="Backbone")
    _config_name: str = field(default="Config")

    def __post_init__(self):
        if self.task == "regression":
            self.loss = self.loss or "MSELoss"
            self.metrics = self.metrics or ["mean_squared_error"]
            self.metrics_params = [{} for _ in self.metrics] if self.metrics_params is None else self.metrics_params
            self.metrics_prob_input = [False for _ in self.metrics]  # not used in Regression. just for compatibility
        elif self.task == "classification":
            self.loss = self.loss or "CrossEntropyLoss"
            self.metrics = self.metrics or ["accuracy"]
            self.metrics_params = [{} for _ in self.metrics] if self.metrics_params is None else self.metrics_params
            self.metrics_prob_input = (
                [False for _ in self.metrics] if self.metrics_prob_input is None else self.metrics_prob_input
            )
        elif self.task == "backbone":
            self.loss = None
            self.metrics = None
            self.metrics_params = None
            if self.head is not None:
                logger.warning("`head` is not a valid parameter for backbone task. Making `head=None`")
                self.head = None
                self.head_config = None
        else:
            raise NotImplementedError(
                f"{self.task} is not a valid task. Should be one of "
                f"{self.__dataclass_fields__['task'].metadata['choices']}"
            )
        if self.metrics is not None:
            assert len(self.metrics) == len(self.metrics_params), "metrics and metric_params should have same length"

        if self.task != "backbone":
            assert self.head in dir(heads.blocks), f"{self.head} is not a valid head"
            if hasattr(self, "_config_name") and self._config_name != "MDNConfig":
                assert self.head != "MixtureDensityHead", "MixtureDensityHead is not supported as a head for regular "
                "models. Use `MDNConfig` instead. Please see Probabilistic Regression with MDN How-to-Guide in "
                "documentation for the right usage."
            _head_callable = getattr(heads.blocks, self.head)
            ideal_head_config = _head_callable._config_template
            invalid_keys = set(self.head_config.keys()) - set(ideal_head_config.__dict__.keys())
            assert len(invalid_keys) == 0, f"`head_config` has some invalid keys: {invalid_keys}"

        # For Custom models, setting these values for compatibility
        if not hasattr(self, "_config_name"):
            self._config_name = type(self).__name__
        if not hasattr(self, "_model_name"):
            self._model_name = re.sub("[Cc]onfig", "Model", self._config_name)
        if not hasattr(self, "_backbone_name"):
            self._backbone_name = re.sub("[Cc]onfig", "Backbone", self._config_name)
        _validate_choices(self)

基础 SSLModel 配置.

Parameters:

Name Type Description Default
encoder_config Optional[ModelConfig]

用于模型的编码器的配置.应为 PyTorch Tabular 中定义的模型配置之一.

None
decoder_config Optional[ModelConfig]

用于模型的解码器的配置.应为 PyTorch Tabular 中定义的模型配置之一.默认为 nn.Identity.

None
embedding_dims Optional[List]

每个分类列的嵌入维度,以元组列表 (基数, 嵌入维度) 的形式表示.如果留空,将根据分类列的基数推断,使用规则 min(50, (x + 1) // 2).

None
embedding_dropout float

应用于分类嵌入的 dropout.默认为 0.1.

0.1
batch_norm_continuous_input bool

如果为 True,我们将通过 BatchNorm 层对连续层进行归一化.

True
virtual_batch_size Optional[int]

如果不为 None,所有 BatchNorm 将被转换为具有指定虚拟批量大小的 GhostBatchNorm.默认为 None.

None
learning_rate float

模型的学习率.默认为 1e-3.

0.001
seed int

用于可重复性的种子.默认为 42.

42
Source code in src/pytorch_tabular/config/config.py
@dataclass
class SSLModelConfig:
    """基础 SSLModel 配置.

    Parameters:
        encoder_config (Optional[ModelConfig]): 用于模型的编码器的配置.应为 PyTorch Tabular 中定义的模型配置之一.

        decoder_config (Optional[ModelConfig]): 用于模型的解码器的配置.应为 PyTorch Tabular 中定义的模型配置之一.默认为 nn.Identity.

        embedding_dims (Optional[List]): 每个分类列的嵌入维度,以元组列表 (基数, 嵌入维度) 的形式表示.如果留空,将根据分类列的基数推断,使用规则 min(50, (x + 1) // 2).

        embedding_dropout (float): 应用于分类嵌入的 dropout.默认为 0.1.

        batch_norm_continuous_input (bool): 如果为 True,我们将通过 BatchNorm 层对连续层进行归一化.

        virtual_batch_size (Optional[int]): 如果不为 None,所有 BatchNorm 将被转换为具有指定虚拟批量大小的 GhostBatchNorm.默认为 None.

        learning_rate (float): 模型的学习率.默认为 1e-3.

        seed (int): 用于可重复性的种子.默认为 42."""

    task: str = field(init=False, default="ssl")

    encoder_config: Optional[ModelConfig] = field(
        default=None,
        metadata={
            "help": "The config of the encoder to be used for the model."
            " Should be one of the model configs defined in PyTorch Tabular",
        },
    )

    decoder_config: Optional[ModelConfig] = field(
        default=None,
        metadata={
            "help": "The config of decoder to be used for the model."
            " Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity",
        },
    )

    embedding_dims: Optional[List] = field(
        default=None,
        metadata={
            "help": "The dimensions of the embedding for each categorical column as a list of tuples "
            "(cardinality, embedding_dim). If left empty, will infer using the cardinality of the "
            "categorical column using the rule min(50, (x + 1) // 2)"
        },
    )
    embedding_dropout: float = field(
        default=0.1,
        metadata={"help": "Dropout to be applied to the Categorical Embedding. Defaults to 0.1"},
    )
    batch_norm_continuous_input: bool = field(
        default=True,
        metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
    )
    virtual_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
            " with this virtual batch size. Defaults to None"
        },
    )
    learning_rate: float = field(
        default=1e-3,
        metadata={"help": "The learning rate of the model. Defaults to 1e-3"},
    )
    seed: int = field(
        default=42,
        metadata={"help": "The seed for reproducibility. Defaults to 42"},
    )

    _module_src: str = field(default="models")
    _model_name: str = field(default="Model")
    _config_name: str = field(default="Config")

    def __post_init__(self):
        assert self.task == "ssl", f"task should be ssl, got {self.task}"
        # For Custom models, setting these values for compatibility
        if not hasattr(self, "_config_name"):
            self._config_name = type(self).__name__
        if not hasattr(self, "_model_name"):
            self._model_name = re.sub("[Cc]onfig", "Model", self._config_name)
        _validate_choices(self)

训练器配置.

Parameters:

Name Type Description Default
batch_size int

每个训练批次中的样本数量

64
data_aware_init_batch_size int

数据感知初始化时每个训练批次中的样本数量, 适用时默认值为2000

2000
fast_dev_run bool

如果设置为n(整数)则运行n个批次,如果设置为True则运行1个批次, 用于查找训练、验证和测试中的任何错误(即一种单元测试).

False
max_epochs int

要运行的最大周期数

10
min_epochs Optional[int]

强制训练至少这些周期数.默认值为1

1
max_time Optional[int]

经过此时间后停止训练.默认禁用(None)

None
accelerator Optional[str]

用于训练的加速器.可以是以下之一: 'cpu','gpu','tpu','ipu', 'mps', 'auto'.默认为'auto'. 可选值为:[cpu,gpu,tpu,ipu,'mps',auto].

'auto'
devices Optional[int]

用于训练的设备数量(整数).-1表示使用所有可用设备. 默认情况下,使用所有可用设备(-1)

-1
devices_list Optional[List[int]]

用于训练的设备列表(列表).如果指定, 优先于devices参数.默认为None

None
accumulate_grad_batches int

每k个批次或按字典设置累积梯度.训练器 也会在最后一个不可整除的步骤数上调用optimizer.step().

1
auto_lr_find bool

在调用trainer.tune()时运行学习率查找算法, 以找到最佳初始学习率.

False
auto_select_gpus bool

如果启用且devices为整数,则自动选择可用GPU. 这在GPU配置为'独占模式'时特别有用,即一次只有一个进程可以访问它们.

True
check_val_every_n_epoch int

每n个训练周期检查一次验证.

1
gradient_clip_val float

梯度裁剪值

0.0
overfit_batches float

使用训练集的此部分数据.如果不为零,将使用相同的 训练集进行验证和测试.如果训练数据加载器的shuffle=True,Lightning 将自动禁用它.对于快速调试或故意过拟合很有用.

0.0
deterministic bool

如果为真,启用cudnn.deterministic.可能会使系统变慢,但 确保可重复性.

False
profiler Optional[str]

在训练期间分析各个步骤并协助识别 瓶颈.可以是None、simple或advanced、pytorch.可选值为: [None,simple,advanced,pytorch].

None
early_stopping Optional[str]

需要监控的损失/指标以进行早停.如果 为None,则不会进行早停

'valid_loss'
early_stopping_min_delta float

早停中损失/指标的最小变化量, 符合改进条件

0.001
early_stopping_mode str

损失/指标应优化的方向.可选值为: [max,min].

'min'
early_stopping_patience int

在损失/指标没有进一步改善之前等待的周期数

3
early_stopping_kwargs Optional[Dict]

早停回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning EarlyStopping回调的文档.

lambda: {}()
checkpoints Optional[str]

需要监控的损失/指标以进行检查点保存.如果为None, 则不会进行检查点保存

'valid_loss'
checkpoints_path str

保存模型的路径

'saved_models'
checkpoints_every_n_epochs int

检查点之间的训练步数

1
checkpoints_name Optional[str]

保存模型的名称.如果留空, 首先会查找experiment_config中的run_name,如果也为None,则使用 类似task_version的通用名称.

None
checkpoints_mode str

损失/指标应优化的方向

'min'
checkpoints_save_top_k int

保存的最佳模型数量

1
checkpoints_kwargs Optional[Dict]

检查点回调的额外关键字参数. 有关更多详细信息,请参阅PyTorch Lightning ModelCheckpoint回调的文档.

lambda: {}()
load_best bool

标志以加载训练期间保存的最佳模型

True
track_grad_norm int

在日志记录器中跟踪和记录梯度范数.默认值-1表示不跟踪. 1表示L1范数,2表示L2范数,依此类推.

-1
progress_bar str

进度条类型.可以是以下之一:none, simple, rich.默认为rich.

'rich'
precision int

模型的精度.可以是以下之一:32, 16, 64.默认为32. 可选值为:[32,16,64].

32
seed int

随机数生成器的种子.默认为42

42
trainer_kwargs Dict[str, Any]

传递给PyTorch Lightning Trainer的额外关键字参数.请参阅 https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer

dict()
Source code in src/pytorch_tabular/config/config.py
@dataclass
class TrainerConfig:
    """训练器配置.

    Parameters:
        batch_size (int): 每个训练批次中的样本数量

        data_aware_init_batch_size (int): 数据感知初始化时每个训练批次中的样本数量,
            适用时默认值为2000

        fast_dev_run (bool): 如果设置为``n``(整数)则运行n个批次,如果设置为``True``则运行1个批次,
                用于查找训练、验证和测试中的任何错误(即一种单元测试).

        max_epochs (int): 要运行的最大周期数

        min_epochs (Optional[int]): 强制训练至少这些周期数.默认值为1

        max_time (Optional[int]): 经过此时间后停止训练.默认禁用(None)

        accelerator (Optional[str]): 用于训练的加速器.可以是以下之一:
                'cpu','gpu','tpu','ipu', 'mps', 'auto'.默认为'auto'.
                可选值为:[`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`].

        devices (Optional[int]): 用于训练的设备数量(整数).-1表示使用所有可用设备.
                默认情况下,使用所有可用设备(-1)

        devices_list (Optional[List[int]]): 用于训练的设备列表(列表).如果指定,
                优先于`devices`参数.默认为None

        accumulate_grad_batches (int): 每k个批次或按字典设置累积梯度.训练器
                也会在最后一个不可整除的步骤数上调用optimizer.step().

        auto_lr_find (bool): 在调用trainer.tune()时运行学习率查找算法,
                以找到最佳初始学习率.

        auto_select_gpus (bool): 如果启用且`devices`为整数,则自动选择可用GPU.
                这在GPU配置为'独占模式'时特别有用,即一次只有一个进程可以访问它们.

        check_val_every_n_epoch (int): 每n个训练周期检查一次验证.

        gradient_clip_val (float): 梯度裁剪值

        overfit_batches (float): 使用训练集的此部分数据.如果不为零,将使用相同的
                训练集进行验证和测试.如果训练数据加载器的shuffle=True,Lightning
                将自动禁用它.对于快速调试或故意过拟合很有用.

        deterministic (bool): 如果为真,启用cudnn.deterministic.可能会使系统变慢,但
                确保可重复性.

        profiler (Optional[str]): 在训练期间分析各个步骤并协助识别
                瓶颈.可以是None、simple或advanced、pytorch.可选值为:
                [`None`,`simple`,`advanced`,`pytorch`].

        early_stopping (Optional[str]): 需要监控的损失/指标以进行早停.如果
                为None,则不会进行早停

        early_stopping_min_delta (float): 早停中损失/指标的最小变化量,
                符合改进条件

        early_stopping_mode (str): 损失/指标应优化的方向.可选值为:
                [`max`,`min`].

        early_stopping_patience (int): 在损失/指标没有进一步改善之前等待的周期数

        early_stopping_kwargs (Optional[Dict]): 早停回调的额外关键字参数.
                有关更多详细信息,请参阅PyTorch Lightning EarlyStopping回调的文档.

        checkpoints (Optional[str]): 需要监控的损失/指标以进行检查点保存.如果为None,
                则不会进行检查点保存

        checkpoints_path (str): 保存模型的路径

        checkpoints_every_n_epochs (int): 检查点之间的训练步数

        checkpoints_name (Optional[str]): 保存模型的名称.如果留空,
                首先会查找experiment_config中的`run_name`,如果也为None,则使用
                类似task_version的通用名称.

        checkpoints_mode (str): 损失/指标应优化的方向

        checkpoints_save_top_k (int): 保存的最佳模型数量

        checkpoints_kwargs (Optional[Dict]): 检查点回调的额外关键字参数.
                有关更多详细信息,请参阅PyTorch Lightning ModelCheckpoint回调的文档.

        load_best (bool): 标志以加载训练期间保存的最佳模型

        track_grad_norm (int): 在日志记录器中跟踪和记录梯度范数.默认值-1表示不跟踪.
                1表示L1范数,2表示L2范数,依此类推.

        progress_bar (str): 进度条类型.可以是以下之一:`none`, `simple`, `rich`.默认为`rich`.

        precision (int): 模型的精度.可以是以下之一:`32`, `16`, `64`.默认为`32`.
                可选值为:[`32`,`16`,`64`].

        seed (int): 随机数生成器的种子.默认为42

        trainer_kwargs (Dict[str, Any]): 传递给PyTorch Lightning Trainer的额外关键字参数.请参阅
                https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.html#pytorch_lightning.trainer.Trainer"""

    batch_size: int = field(default=64, metadata={"help": "Number of samples in each batch of training"})
    data_aware_init_batch_size: int = field(
        default=2000,
        metadata={
            "help": "Number of samples in each batch of training for the data-aware initialization,"
            " when applicable. Defaults to 2000"
        },
    )
    fast_dev_run: bool = field(
        default=False,
        metadata={
            "help": "runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train,"
            " val and test to find any bugs (ie: a sort of unit test)."
        },
    )
    max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
    min_epochs: Optional[int] = field(
        default=1,
        metadata={"help": "Force training for at least these many epochs. 1 by default"},
    )
    max_time: Optional[int] = field(
        default=None,
        metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
    )
    accelerator: Optional[str] = field(
        default="auto",
        metadata={
            "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'."
            " Defaults to 'auto'",
            "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
        },
    )
    devices: Optional[int] = field(
        default=-1,
        metadata={
            "help": "Number of devices to train on. -1 uses all available devices."
            " By default uses all available devices (-1)",
        },
    )
    devices_list: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
            " Defaults to None",
        },
    )

    accumulate_grad_batches: int = field(
        default=1,
        metadata={
            "help": "Accumulates grads every k batches or as set up in the dict."
            " Trainer also calls optimizer.step() for the last indivisible step number."
        },
    )
    auto_lr_find: bool = field(
        default=False,
        metadata={
            "help": "Runs a learning rate finder algorithm (see this paper) when calling trainer.tune(),"
            " to find optimal initial learning rate."
        },
    )
    auto_select_gpus: bool = field(
        default=True,
        metadata={
            "help": "If enabled and `devices` is an integer, pick available gpus automatically."
            " This is especially useful when GPUs are configured to be in 'exclusive mode',"
            " such that only one process at a time can access them."
        },
    )
    check_val_every_n_epoch: int = field(default=1, metadata={"help": "Check val every n train epochs."})
    gradient_clip_val: float = field(default=0.0, metadata={"help": "Gradient clipping value"})
    overfit_batches: float = field(
        default=0.0,
        metadata={
            "help": "Uses this much data of the training set. If nonzero, will use the same training set"
            " for validation and testing. If the training dataloaders have shuffle=True,"
            " Lightning will automatically disable it."
            " Useful for quickly debugging or trying to overfit on purpose."
        },
    )
    deterministic: bool = field(
        default=False,
        metadata={
            "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
        },
    )
    profiler: Optional[str] = field(
        default=None,
        metadata={
            "help": "To profile individual steps during training and assist in identifying bottlenecks."
            " None, simple or advanced, pytorch",
            "choices": [None, "simple", "advanced", "pytorch"],
        },
    )
    early_stopping: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric that needed to be monitored for early stopping."
            " If None, there will be no early stopping"
        },
    )
    early_stopping_min_delta: float = field(
        default=0.001,
        metadata={"help": "The minimum delta in the loss/metric which qualifies as an improvement in early stopping"},
    )
    early_stopping_mode: str = field(
        default="min",
        metadata={
            "help": "The direction in which the loss/metric should be optimized",
            "choices": ["max", "min"],
        },
    )
    early_stopping_patience: int = field(
        default=3,
        metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
    )
    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the early stopping callback."
            " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
        },
    )
    checkpoints: Optional[str] = field(
        default="valid_loss",
        metadata={
            "help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints"
        },
    )
    checkpoints_path: str = field(
        default="saved_models",
        metadata={"help": "The path where the saved models will be"},
    )
    checkpoints_every_n_epochs: int = field(
        default=1,
        metadata={"help": "Number of training steps between checkpoints"},
    )
    checkpoints_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name under which the models will be saved. If left blank,"
            " first it will look for `run_name` in experiment_config and if that is also None"
            " then it will use a generic name like task_version."
        },
    )
    checkpoints_mode: str = field(
        default="min",
        metadata={"help": "The direction in which the loss/metric should be optimized"},
    )
    checkpoints_save_top_k: int = field(
        default=1,
        metadata={"help": "The number of best models to save"},
    )
    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
            " for the PyTorch Lightning ModelCheckpoint callback for more details."
        },
    )
    load_best: bool = field(
        default=True,
        metadata={"help": "Flag to load the best model saved during training"},
    )
    track_grad_norm: int = field(
        default=-1,
        metadata={
            "help": "Track and Log Gradient Norms in the logger. -1 by default means no tracking. "
            "1 for the L1 norm, 2 for L2 norm, etc."
        },
    )
    progress_bar: str = field(
        default="rich",
        metadata={"help": "Progress bar type. Can be one of: `none`, `simple`, `rich`. Defaults to `rich`."},
    )
    precision: int = field(
        default=32,
        metadata={
            "help": "Precision of the model. Can be one of: `32`, `16`, `64`. Defaults to `32`.",
            "choices": [32, 16, 64],
        },
    )
    seed: int = field(
        default=42,
        metadata={"help": "Seed for random number generators. Defaults to 42"},
    )
    trainer_kwargs: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
    )

    def __post_init__(self):
        _validate_choices(self)
        if self.accelerator is None:
            self.accelerator = "cpu"
        if self.devices_list is not None:
            self.devices = self.devices_list
        delattr(self, "devices_list")
        for key in self.early_stopping_kwargs.keys():
            if key in ["min_delta", "mode", "patience"]:
                raise ValueError(
                    f"Cannot override {key} in early_stopping_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )
        for key in self.checkpoints_kwargs.keys():
            if key in ["dirpath", "filename", "monitor", "save_top_k", "mode", "every_n_epochs"]:
                raise ValueError(
                    f"Cannot override {key} in checkpoints_kwargs."
                    f" Please use the appropriate argument in `TrainerConfig`"
                )

实验配置.使用 WandB 和 Tensorboard 进行实验跟踪.

Parameters:

Name Type Description Default
project_name str

所有运行日志所属的项目名称.对于 Tensorboard,这定义了日志保存的文件夹, 对于 W&B,这定义了项目名称.

MISSING
run_name Optional[str]

运行的名称;用于识别该运行的特定标识符.如果留空,将自动生成一个名称.

None
exp_watch Optional[str]

所需的日志记录级别.可以是 gradientsparametersallNone. 默认为 None.可选值为:[gradients, parameters, all, None].

None
log_target str

确定日志记录发生的位置 - Tensorboard 或 W&B.可选值为:[wandb, tensorboard].

'tensorboard'
log_logits bool

开启此选项以在 W&B 中将 logits 记录为直方图.

False
exp_log_freq int

记录梯度和参数的步数间隔.

100
Source code in src/pytorch_tabular/config/config.py
@dataclass
class ExperimentConfig:
    """实验配置.使用 WandB 和 Tensorboard 进行实验跟踪.

    Args:
        project_name (str): 所有运行日志所属的项目名称.对于 Tensorboard,这定义了日志保存的文件夹,
                对于 W&B,这定义了项目名称.

        run_name (Optional[str]): 运行的名称;用于识别该运行的特定标识符.如果留空,将自动生成一个名称.

        exp_watch (Optional[str]): 所需的日志记录级别.可以是 `gradients`、`parameters`、`all` 或 `None`.
                默认为 None.可选值为:[`gradients`, `parameters`, `all`, `None`].

        log_target (str): 确定日志记录发生的位置 - Tensorboard 或 W&B.可选值为:[`wandb`, `tensorboard`].

        log_logits (bool): 开启此选项以在 W&B 中将 logits 记录为直方图.

        exp_log_freq (int): 记录梯度和参数的步数间隔."""

    project_name: str = field(
        default=MISSING,
        metadata={
            "help": "The name of the project under which all runs will be logged."
            " For Tensorboard this defines the folder under which the logs will be saved"
            " and for W&B it defines the project name"
        },
    )

    run_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the run; a specific identifier to recognize the run."
            " If left blank, will be assigned a auto-generated name"
        },
    )
    exp_watch: Optional[str] = field(
        default=None,
        metadata={
            "help": "The level of logging required.  Can be `gradients`, `parameters`, `all` or `None`."
            " Defaults to None",
            "choices": ["gradients", "parameters", "all", None],
        },
    )

    log_target: str = field(
        default="tensorboard",
        metadata={
            "help": "Determines where logging happens - Tensorboard or W&B",
            "choices": ["wandb", "tensorboard"],
        },
    )
    log_logits: bool = field(
        default=False,
        metadata={"help": "Turn this on to log the logits as a histogram in W&B"},
    )

    exp_log_freq: int = field(
        default=100,
        metadata={"help": "step count between logging of gradients and parameters."},
    )

    def __post_init__(self):
        _validate_choices(self)
        if self.log_target == "wandb":
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError(
                    "No W&B installation detected. `pip install wandb` to install W&B if you set log_target as `wandb`"
                )

优化器和学习率调度器配置.

Parameters:

Name Type Description Default
optimizer str

来自 torch.optim 的标准优化器之一, 或提供完整的 Python 路径,例如 "torch_optimizer.RAdam".

'Adam'
optimizer_params Dict

优化器的参数.如果留空,将使用默认参数.

lambda: {}()
lr_scheduler Optional[str]

要使用的学习率调度器的名称,如果有的话,来自 torch.optim.lr_scheduler.如果为 None,将不使用任何调度器.默认为 None.

None
lr_scheduler_params Optional[Dict]

学习率调度器的参数.如果留空,将使用默认参数.

lambda: {}()
lr_scheduler_monitor_metric Optional[str]

与 ReduceLROnPlateau 一起使用,其中平台是基于此指标决定的.

'valid_loss'
Source code in src/pytorch_tabular/config/config.py
@dataclass
class OptimizerConfig:
    """优化器和学习率调度器配置.

    Parameters:
        optimizer (str): 来自 [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) 的标准优化器之一,
                或提供完整的 Python 路径,例如 "torch_optimizer.RAdam".

        optimizer_params (Dict): 优化器的参数.如果留空,将使用默认参数.

        lr_scheduler (Optional[str]): 要使用的学习率调度器的名称,如果有的话,来自
                [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-
                rate).如果为 None,将不使用任何调度器.默认为 `None`.

        lr_scheduler_params (Optional[Dict]): 学习率调度器的参数.如果留空,将使用默认参数.

        lr_scheduler_monitor_metric (Optional[str]): 与 ReduceLROnPlateau 一起使用,其中平台是基于此指标决定的."""

    optimizer: str = field(
        default="Adam",
        metadata={
            "help": "Any of the standard optimizers from"
            " [torch.optim](https://pytorch.org/docs/stable/optim.html#algorithms) or provide full python path,"
            " for example 'torch_optimizer.RAdam'."
        },
    )
    optimizer_params: Dict = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
    )
    lr_scheduler: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the LearningRateScheduler to use, if any, from"
            " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
            " If None, will not use any scheduler. Defaults to `None`",
        },
    )
    lr_scheduler_params: Optional[Dict] = field(
        default_factory=lambda: {},
        metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
    )

    lr_scheduler_monitor_metric: Optional[str] = field(
        default="valid_loss",
        metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
    )

    @staticmethod
    def read_from_yaml(filename: str = "config/optimizer_config.yml"):
        config = _read_yaml(filename)
        if config["lr_scheduler_params"] is None:
            config["lr_scheduler_params"] = {}
        return OptimizerConfig(**config)
Source code in src/pytorch_tabular/config/config.py
class ExperimentRunManager:
    def __init__(
        self,
        exp_version_manager: str = ".pt_tmp/exp_version_manager.yml",
    ) -> None:
        """    管理基于名称的实验版本.它是一个简单的基于字典(yaml)的查找.
主要目的是在运行训练而不更改实验名称时避免保存模型的覆盖.

Parameters:
    exp_version_manager (str, 可选): 作为版本控制的yml文件的路径.
        默认为 ".pt_tmp/exp_version_manager.yml".
"""
        super().__init__()
        self._exp_version_manager = exp_version_manager
        if os.path.exists(exp_version_manager):
            self.exp_version_manager = OmegaConf.load(exp_version_manager)
        else:
            self.exp_version_manager = OmegaConf.create({})
            os.makedirs(os.path.split(exp_version_manager)[0], exist_ok=True)
            with open(self._exp_version_manager, "w") as file:
                OmegaConf.save(config=self.exp_version_manager, f=file)

    def update_versions(self, name):
        if name in self.exp_version_manager.keys():
            uid = self.exp_version_manager[name] + 1
        else:
            uid = 1
        self.exp_version_manager[name] = uid
        with open(self._exp_version_manager, "w") as file:
            OmegaConf.save(config=self.exp_version_manager, f=file)
        return uid

__init__(exp_version_manager='.pt_tmp/exp_version_manager.yml')

管理基于名称的实验版本.它是一个简单的基于字典(yaml)的查找. 主要目的是在运行训练而不更改实验名称时避免保存模型的覆盖.

Parameters:

Name Type Description Default
exp_version_manager (str, 可选)

作为版本控制的yml文件的路径. 默认为 ".pt_tmp/exp_version_manager.yml".

'.pt_tmp/exp_version_manager.yml'
Source code in src/pytorch_tabular/config/config.py
    def __init__(
        self,
        exp_version_manager: str = ".pt_tmp/exp_version_manager.yml",
    ) -> None:
        """    管理基于名称的实验版本.它是一个简单的基于字典(yaml)的查找.
主要目的是在运行训练而不更改实验名称时避免保存模型的覆盖.

Parameters:
    exp_version_manager (str, 可选): 作为版本控制的yml文件的路径.
        默认为 ".pt_tmp/exp_version_manager.yml".
"""
        super().__init__()
        self._exp_version_manager = exp_version_manager
        if os.path.exists(exp_version_manager):
            self.exp_version_manager = OmegaConf.load(exp_version_manager)
        else:
            self.exp_version_manager = OmegaConf.create({})
            os.makedirs(os.path.split(exp_version_manager)[0], exist_ok=True)
            with open(self._exp_version_manager, "w") as file:
                OmegaConf.save(config=self.exp_version_manager, f=file)

头部配置

除了这些核心类之外,我们还有用于头部的配置类

线性头配置的模型类;作为模板和文档使用.模型接受字典作为输入,但如果存在本模型类中不存在的键,则会抛出异常.

Parameters:

Name Type Description Default
layers str

分类/回归头中层数和单元数的连字符分隔字符串. 例如:32-64-32.默认情况下,仅从输入维度映射到输出维度.

''
activation str

分类头中的激活类型.默认激活类型类似于PyTorch中的ReLU、TanH、LeakyReLU等. 参考:https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

'ReLU'
dropout float

分类元素被置零的概率.

0.0
use_batch_norm bool

标志,用于在每个线性层+DropOut后添加BatchNorm层.

False
initialization str

线性层的初始化方案.默认为kaiming.可选方案有:[kaiming,xavier,random].

'kaiming'
Source code in src/pytorch_tabular/models/common/heads/config.py
@dataclass
class LinearHeadConfig:
    """线性头配置的模型类;作为模板和文档使用.模型接受字典作为输入,但如果存在本模型类中不存在的键,则会抛出异常.

    Args:
        layers (str): 分类/回归头中层数和单元数的连字符分隔字符串.
                例如:32-64-32.默认情况下,仅从输入维度映射到输出维度.

        activation (str): 分类头中的激活类型.默认激活类型类似于PyTorch中的ReLU、TanH、LeakyReLU等.
                参考:https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

        dropout (float): 分类元素被置零的概率.

        use_batch_norm (bool): 标志,用于在每个线性层+DropOut后添加BatchNorm层.

        initialization (str): 线性层的初始化方案.默认为`kaiming`.可选方案有:[`kaiming`,`xavier`,`random`]."""

    layers: str = field(
        default="",
        metadata={
            "help": "Hyphen-separated number of layers and units in the classification/regression head. eg. 32-64-32."
            " Default is just a mapping from intput dimension to output dimension"
        },
    )
    activation: str = field(
        default="ReLU",
        metadata={
            "help": "The activation type in the classification head. The default activation in PyTorch"
            " like ReLU, TanH, LeakyReLU, etc."
            " https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity"
        },
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "probability of an classification element to be zeroed."},
    )
    use_batch_norm: bool = field(
        default=False,
        metadata={"help": "Flag to include a BatchNorm layer after each Linear Layer+DropOut"},
    )
    initialization: str = field(
        default="kaiming",
        metadata={
            "help": "Initialization scheme for the linear layers. Defaults to `kaiming`",
            "choices": ["kaiming", "xavier", "random"],
        },
    )

混合密度网络头配置.

Parameters:

Name Type Description Default
num_gaussian int

混合模型中高斯分布的数量.默认为1

1
sigma_bias_flag bool

是否在sigma层中包含偏置项.默认为False

False
mu_bias_init Optional[List]

将mu层的偏置参数初始化为预定义的聚类中心.应为一个与混合模型中高斯数量相同长度的列表.强烈建议设置此参数以对抗模式崩溃.默认为None

None
weight_regularization Optional[int]

是否对MDN层应用L1或L2范数.默认为L2.可选值为: [1,2]

2
lambda_sigma Optional[float]

sigma层权重正则化的正则化常数.默认为0.1

0.1
lambda_pi Optional[float]

pi层权重正则化的正则化常数.默认为0.1

0.1
lambda_mu Optional[float]

mu层权重正则化的正则化常数.默认为0

0
softmax_temperature Optional[float]

用于混合系数gumbel softmax的温度.小于1的值会导致多个成分之间的过渡更尖锐.默认为1

1
n_samples int

从后验分布中抽取样本以获得预测的数量.默认为100

100
central_tendency str

用于获取点预测的度量方法.默认为均值.可选值为: [mean,median]

'mean'
speedup_training bool

开启此参数将取消训练期间的采样,从而加快训练速度,但也会使您无法查看训练指标.默认为False

False
log_debug_plot bool

开启此参数将绘制mu、sigma和pi层的直方图,以及logits(如果在实验配置中开启了log_logits).默认为False

False
input_dim int

输入到头部的维度.这将在从backbone.output_dim初始化时自动填充

None
Source code in src/pytorch_tabular/models/common/heads/config.py
@dataclass
class MixtureDensityHeadConfig:
    """混合密度网络头配置.

    Parameters:
        num_gaussian (int): 混合模型中高斯分布的数量.默认为1

        sigma_bias_flag (bool): 是否在sigma层中包含偏置项.默认为False

        mu_bias_init (Optional[List]): 将mu层的偏置参数初始化为预定义的聚类中心.应为一个与混合模型中高斯数量相同长度的列表.强烈建议设置此参数以对抗模式崩溃.默认为None

        weight_regularization (Optional[int]): 是否对MDN层应用L1或L2范数.默认为L2.可选值为: [`1`,`2`]

        lambda_sigma (Optional[float]): sigma层权重正则化的正则化常数.默认为0.1

        lambda_pi (Optional[float]): pi层权重正则化的正则化常数.默认为0.1

        lambda_mu (Optional[float]): mu层权重正则化的正则化常数.默认为0

        softmax_temperature (Optional[float]): 用于混合系数gumbel softmax的温度.小于1的值会导致多个成分之间的过渡更尖锐.默认为1

        n_samples (int): 从后验分布中抽取样本以获得预测的数量.默认为100

        central_tendency (str): 用于获取点预测的度量方法.默认为均值.可选值为: [`mean`,`median`]

        speedup_training (bool): 开启此参数将取消训练期间的采样,从而加快训练速度,但也会使您无法查看训练指标.默认为False

        log_debug_plot (bool): 开启此参数将绘制mu、sigma和pi层的直方图,以及logits(如果在实验配置中开启了log_logits).默认为False

        input_dim (int): 输入到头部的维度.这将在从`backbone.output_dim`初始化时自动填充"""

    num_gaussian: int = field(
        default=1,
        metadata={
            "help": "Number of Gaussian Distributions in the mixture model. Defaults to 1",
        },
    )
    sigma_bias_flag: bool = field(
        default=False,
        metadata={
            "help": "Whether to have a bias term in the sigma layer. Defaults to False",
        },
    )
    mu_bias_init: Optional[List] = field(
        default=None,
        metadata={
            "help": "To initialize the bias parameter of the mu layer to predefined cluster centers."
            " Should be a list with the same length as number of gaussians in the mixture model."
            " It is highly recommended to set the parameter to combat mode collapse. Defaults to None",
        },
    )

    weight_regularization: Optional[int] = field(
        default=2,
        metadata={
            "help": "Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2",
            "choices": [1, 2],
        },
    )

    lambda_sigma: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "The regularization constant for weight regularization of sigma layer. Defaults to 0.1",
        },
    )
    lambda_pi: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "The regularization constant for weight regularization of pi layer. Defaults to 0.1",
        },
    )
    lambda_mu: Optional[float] = field(
        default=0,
        metadata={
            "help": "The regularization constant for weight regularization of mu layer. Defaults to 0",
        },
    )
    softmax_temperature: Optional[float] = field(
        default=1,
        metadata={
            "help": "The temperature to be used in the gumbel softmax of the mixing coefficients."
            " Values less than one leads to sharper transition between the multiple components. Defaults to 1",
        },
    )
    n_samples: int = field(
        default=100,
        metadata={
            "help": "Number of samples to draw from the posterior to get prediction. Defaults to 100",
        },
    )
    central_tendency: str = field(
        default="mean",
        metadata={
            "help": "Which measure to use to get the point prediction. Defaults to mean",
            "choices": ["mean", "median"],
        },
    )
    speedup_training: bool = field(
        default=False,
        metadata={
            "help": "Turning on this parameter does away with sampling during training which speeds up training,"
            " but also doesn't give you visibility on train metrics. Defaults to False",
        },
    )
    log_debug_plot: bool = field(
        default=False,
        metadata={
            "help": "Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition"
            " to the logits(if log_logits is turned on in experment config). Defaults to False",
        },
    )
    input_dim: int = field(
        default=None,
        metadata={
            "help": "The input dimensions to the head. This will be automatically filled in while initializing"
            " from the `backbone.output_dim`",
        },
    )
    _probabilistic: bool = field(default=True)