数据
PyTorch Tabular 使用 Pandas DataFrame 作为数据容器。由于 Pandas 是处理表格数据最流行的方式,这是一个显而易见的选择。考虑到易用性,PyTorch Tabular 直接接受 DataFrame,即不需要像 Sci-kit Learn 那样将数据拆分为 X
和 y
。
PyTorch Tabular 通过 DataConfig
对象来处理这些数据。
基本用法¶
target
: List[str]: 包含目标列名称的字符串列表continuous_cols
: List[str]: 数值字段的列名。默认为 []categorical_cols
: List[str]: 需要特殊处理的分类字段的列名
使用示例¶
data_config = DataConfig(
target=["label"],
continuous_cols=["feature_1", "feature_2"],
categorical_cols=["cat_feature_1", "cat_feature_2"],
)
高级用法:¶
日期列¶
如果你的 DataFrame 中有日期列,可以在 date_columns
参数中提及列名,并将 encode_date_columns
设置为 True
。这将提取相关特征,如月份、周、季度等,并在内部将其添加到特征列表中。
date_columns
不仅是一个列名列表,而是一个 (列名, 频率) 元组的列表。频率是标准的 Pandas 日期频率标签,表示问题中相关的时间粒度。
例如,如果有一个产品发布日期的列,并且它们每月只发布一次。那么提取周或日等特征是没有意义的。因此,我们将频率保持在 M
特征变换¶
特征缩放几乎是大多数机器学习算法获得良好性能的必要步骤,深度学习也不例外。normalize_continuous_features
标志(默认为 True
)使用 StandardScaler
对输入的连续特征进行缩放。
有时,使用非线性变换改变特征分布有助于机器学习/深度学习算法。
PyTorch Tabular 通过 continuous_feature_transform
参数提供了 4 种标准变换:
yeo-johnson
box-cox
quantile_uniform
quantile_normal
yeo-johnson
和 box-cox
是一组参数化的单调变换,旨在将数据从任何分布映射到尽可能接近高斯分布,以稳定方差并最小化偏度。box-cox
只能应用于严格正数据。Sci-kit Learn 对此有很好的说明
quantile_normal
和 quantile_uniform
是单调的非参数变换,分别旨在将特征转换为正态分布或均匀分布。通过执行秩变换,分位数变换平滑了不寻常的分布,并且比缩放方法受异常值的影响更小。然而,它确实扭曲了特征内和特征间的相关性和距离。
pytorch_tabular.config.DataConfig
dataclass
¶
数据配置.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target
|
Optional[List[str]]
|
目标列名称的字符串列表.对于SSL任务以外的所有任务都是必需的. |
None
|
continuous_cols
|
List
|
数值字段的列名称.默认为 [] |
list()
|
categorical_cols
|
List
|
需要特殊处理的分类字段的列名称.默认为 [] |
list()
|
date_columns
|
List
|
(列名, 频率, 格式) 元组形式的日期字段.例如,名为 introduction_date 且频率为 "2023-12" 的月度字段应有一个条目 ('intro_date','M','%Y-%m') |
list()
|
encode_date_columns
|
bool
|
是否对从日期派生的变量进行编码 |
True
|
validation_split
|
Optional[float]
|
保留为验证集的训练行百分比.仅在未单独提供验证数据时使用 |
0.2
|
continuous_feature_transform
|
Optional[str]
|
是否在建模前对特征进行变换.默认关闭.可选值为: [ |
None
|
normalize_continuous_features
|
bool
|
是否对输入特征(连续)进行归一化 |
True
|
quantile_noise
|
int
|
未实现.如果指定,将在数据上拟合 QuantileTransformer,并添加标准差为 :quantile_noise: * data.std 的高斯噪声;这将使离散值更易分离.请注意,此变换不会对结果数据应用高斯噪声,噪声仅用于 QuantileTransformer |
0
|
num_workers
|
Optional[int]
|
用于数据加载的工作线程数.对于 Windows 系统始终设置为 0 |
0
|
pin_memory
|
bool
|
是否为数据加载固定内存 |
True
|
handle_unknown_categories
|
bool
|
是否将分类列中的未知或新值处理为未知 |
True
|
handle_missing_values
|
bool
|
是否将分类列中的缺失值处理为未知 |
True
|
Source code in src/pytorch_tabular/config/config.py
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
|