Skip to content

数据

PyTorch Tabular 使用 Pandas DataFrame 作为数据容器。由于 Pandas 是处理表格数据最流行的方式,这是一个显而易见的选择。考虑到易用性,PyTorch Tabular 直接接受 DataFrame,即不需要像 Sci-kit Learn 那样将数据拆分为 Xy

PyTorch Tabular 通过 DataConfig 对象来处理这些数据。

基本用法

  • target: List[str]: 包含目标列名称的字符串列表
  • continuous_cols: List[str]: 数值字段的列名。默认为 []
  • categorical_cols: List[str]: 需要特殊处理的分类字段的列名

使用示例

data_config = DataConfig(
    target=["label"],
    continuous_cols=["feature_1", "feature_2"],
    categorical_cols=["cat_feature_1", "cat_feature_2"],
)

高级用法:

日期列

如果你的 DataFrame 中有日期列,可以在 date_columns 参数中提及列名,并将 encode_date_columns 设置为 True。这将提取相关特征,如月份、周、季度等,并在内部将其添加到特征列表中。

date_columns 不仅是一个列名列表,而是一个 (列名, 频率) 元组的列表。频率是标准的 Pandas 日期频率标签,表示问题中相关的时间粒度。

例如,如果有一个产品发布日期的列,并且它们每月只发布一次。那么提取周或日等特征是没有意义的。因此,我们将频率保持在 M

date_columns = [("launch_date", "M")]

特征变换

特征缩放几乎是大多数机器学习算法获得良好性能的必要步骤,深度学习也不例外。normalize_continuous_features 标志(默认为 True)使用 StandardScaler 对输入的连续特征进行缩放。

有时,使用非线性变换改变特征分布有助于机器学习/深度学习算法。

PyTorch Tabular 通过 continuous_feature_transform 参数提供了 4 种标准变换:

  • yeo-johnson
  • box-cox
  • quantile_uniform
  • quantile_normal

yeo-johnsonbox-cox 是一组参数化的单调变换,旨在将数据从任何分布映射到尽可能接近高斯分布,以稳定方差并最小化偏度。box-cox 只能应用于严格正数据。Sci-kit Learn 对此有很好的说明

quantile_normalquantile_uniform 是单调的非参数变换,分别旨在将特征转换为正态分布或均匀分布。通过执行秩变换,分位数变换平滑了不寻常的分布,并且比缩放方法受异常值的影响更小。然而,它确实扭曲了特征内和特征间的相关性和距离。

pytorch_tabular.config.DataConfig dataclass

数据配置.

Parameters:

Name Type Description Default
target Optional[List[str]]

目标列名称的字符串列表.对于SSL任务以外的所有任务都是必需的.

None
continuous_cols List

数值字段的列名称.默认为 []

list()
categorical_cols List

需要特殊处理的分类字段的列名称.默认为 []

list()
date_columns List

(列名, 频率, 格式) 元组形式的日期字段.例如,名为 introduction_date 且频率为 "2023-12" 的月度字段应有一个条目 ('intro_date','M','%Y-%m')

list()
encode_date_columns bool

是否对从日期派生的变量进行编码

True
validation_split Optional[float]

保留为验证集的训练行百分比.仅在未单独提供验证数据时使用

0.2
continuous_feature_transform Optional[str]

是否在建模前对特征进行变换.默认关闭.可选值为: [None,yeo-johnson,box-cox, quantile_normal,quantile_uniform].

None
normalize_continuous_features bool

是否对输入特征(连续)进行归一化

True
quantile_noise int

未实现.如果指定,将在数据上拟合 QuantileTransformer,并添加标准差为 :quantile_noise: * data.std 的高斯噪声;这将使离散值更易分离.请注意,此变换不会对结果数据应用高斯噪声,噪声仅用于 QuantileTransformer

0
num_workers Optional[int]

用于数据加载的工作线程数.对于 Windows 系统始终设置为 0

0
pin_memory bool

是否为数据加载固定内存

True
handle_unknown_categories bool

是否将分类列中的未知或新值处理为未知

True
handle_missing_values bool

是否将分类列中的缺失值处理为未知

True
Source code in src/pytorch_tabular/config/config.py
@dataclass
class DataConfig:
    """数据配置.

    Parameters:
        target (Optional[List[str]]): 目标列名称的字符串列表.对于SSL任务以外的所有任务都是必需的.

        continuous_cols (List): 数值字段的列名称.默认为 []

        categorical_cols (List): 需要特殊处理的分类字段的列名称.默认为 []

        date_columns (List): (列名, 频率, 格式) 元组形式的日期字段.例如,名为 introduction_date 且频率为 "2023-12" 的月度字段应有一个条目 ('intro_date','M','%Y-%m')

        encode_date_columns (bool): 是否对从日期派生的变量进行编码

        validation_split (Optional[float]): 保留为验证集的训练行百分比.仅在未单独提供验证数据时使用

        continuous_feature_transform (Optional[str]): 是否在建模前对特征进行变换.默认关闭.可选值为: [`None`,`yeo-johnson`,`box-cox`, `quantile_normal`,`quantile_uniform`].

        normalize_continuous_features (bool): 是否对输入特征(连续)进行归一化

        quantile_noise (int): 未实现.如果指定,将在数据上拟合 QuantileTransformer,并添加标准差为 :quantile_noise: * data.std 的高斯噪声;这将使离散值更易分离.请注意,此变换不会对结果数据应用高斯噪声,噪声仅用于 QuantileTransformer

        num_workers (Optional[int]): 用于数据加载的工作线程数.对于 Windows 系统始终设置为 0

        pin_memory (bool): 是否为数据加载固定内存

        handle_unknown_categories (bool): 是否将分类列中的未知或新值处理为未知

        handle_missing_values (bool): 是否将分类列中的缺失值处理为未知"""

    target: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "A list of strings with the names of the target column(s)."
            " It is mandatory for all except SSL tasks."
        },
    )
    continuous_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the numeric fields. Defaults to []"},
    )
    categorical_cols: List = field(
        default_factory=list,
        metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
    )
    date_columns: List = field(
        default_factory=list,
        metadata={
            "help": "(Column names, Freq) tuples of the date fields. For eg. a field named"
            " introduction_date and with a monthly frequency like '2023-12' should have"
            " an entry ('intro_date','M','%Y-%m')"
        },
    )

    encode_date_columns: bool = field(
        default=True,
        metadata={"help": "Whether or not to encode the derived variables from date"},
    )
    validation_split: Optional[float] = field(
        default=0.2,
        metadata={
            "help": "Percentage of Training rows to keep aside as validation."
            " Used only if Validation Data is not given separately"
        },
    )
    continuous_feature_transform: Optional[str] = field(
        default=None,
        metadata={
            "help": "Whether or not to transform the features before modelling. By default it is turned off.",
            "choices": [
                None,
                "yeo-johnson",
                "box-cox",
                "quantile_normal",
                "quantile_uniform",
            ],
        },
    )
    normalize_continuous_features: bool = field(
        default=True,
        metadata={"help": "Flag to normalize the input features (continuous)"},
    )
    quantile_noise: int = field(
        default=0,
        metadata={
            "help": "NOT IMPLEMENTED. If specified fits QuantileTransformer on data with added gaussian noise"
            " with std = :quantile_noise: * data.std ; this will cause discrete values to be more separable."
            " Please not that this transformation does NOT apply gaussian noise to the resulting data,"
            " the noise is only applied for QuantileTransformer"
        },
    )
    num_workers: Optional[int] = field(
        default=0,
        metadata={"help": "The number of workers used for data loading. For windows always set to 0"},
    )
    pin_memory: bool = field(
        default=True,
        metadata={"help": "Whether or not to pin memory for data loading."},
    )
    handle_unknown_categories: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle unknown or new values in categorical columns as unknown"},
    )
    handle_missing_values: bool = field(
        default=True,
        metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
    )

    def __post_init__(self):
        assert (
            len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
        ), "There should be at-least one feature defined in categorical, continuous, or date columns"
        _validate_choices(self)
        if os.name == "nt" and self.num_workers != 0:
            print("Windows does not support num_workers > 0. Setting num_workers to 0")
            self.num_workers = 0