自监督模型
配置类¶
Bases: SSLModelConfig
去噪自动编码器配置.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
noise_strategy
|
str
|
定义我们向样本引入什么样的噪声. |
'swap'
|
noise_probabilities
|
Dict[str, float]
|
用于以交换/零噪声破坏输入特征的个体概率字典.键应为特征名称,如果缺少任何特征,则使用default_noise_probability.默认为空字典() |
lambda: {}()
|
default_noise_probability
|
float
|
用于以交换/零噪声破坏输入特征的默认概率.对于noise_probabilities未定义概率的特征.默认为0.8 |
0.8
|
loss_type_weights
|
Optional[List[float]]
|
用于损失函数的权重,顺序为[二进制, 分类, 数值].如果为None,将使用默认权重,使用公式计算.例如,对于二进制,默认权重将为n_binary/n_features.默认为None |
None
|
mask_loss_weight
|
float
|
用于掩码特征损失函数的权重.默认为1.0 |
2.0
|
max_onehot_cardinality
|
int
|
独热编码分类特征的最大基数.任何基数>max_onehot_cardinality的分类特征将在学习的嵌入空间中嵌入,其他特征将转换为独热表示.如果设置为0,将对所有分类特征使用嵌入策略.默认为4 |
4
|
include_input_features_inference
|
bool
|
如果为True,将在微调时包含输入特征以及学习到的特征.默认为False |
False
|
encoder_config
|
Optional[ModelConfig]
|
用于模型的编码器的配置.应为PyTorch Tabular中定义的模型配置之一 |
None
|
decoder_config
|
Optional[ModelConfig]
|
用于模型的解码器的配置.应为PyTorch Tabular中定义的模型配置之一.默认为nn.Identity |
None
|
embedding_dims
|
Optional[List]
|
每个分类列的嵌入维度,格式为(基数, 嵌入维度)的列表.如果为空,将根据分类列的基数推断,使用规则min(50, (x + 1) // 2) |
None
|
embedding_dropout
|
float
|
应用于分类嵌入的丢弃率.默认为0.1 |
0.1
|
batch_norm_continuous_input
|
bool
|
如果为True,我们将通过BatchNorm层对连续层进行归一化. |
True
|
learning_rate
|
float
|
模型的学习率.默认为1e-3 |
0.001
|
seed
|
int
|
用于可重复性的种子.默认为42 |
42
|
Source code in src/pytorch_tabular/ssl_models/dae/config.py
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
|
模型类¶
Bases: SSLBaseModel
Source code in src/pytorch_tabular/ssl_models/dae/dae.py
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
|
基础模型类¶
Bases: LightningModule
Source code in src/pytorch_tabular/ssl_models/base_model.py
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
|
__init__(config, mode='pretrain', encoder=None, decoder=None, custom_optimizer=None, custom_optimizer_params={}, **kwargs)
¶
所有SSL模型的基础模型.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config
|
DictConfig
|
用户定义的配置 |
required |
mode
|
(str, 可选)
|
模型的模式.默认为 "pretrain". |
'pretrain'
|
encoder
|
(Optional[Module], 可选)
|
模型的编码器.默认为 None. |
None
|
decoder
|
(Optional[Module], 可选)
|
模型的解码器.默认为 None. |
None
|
custom_optimizer
|
(Optional[Optimizer], 可选)
|
要使用的自定义优化器.默认为 None. |
None
|
custom_optimizer_params
|
(Dict, 可选)
|
要使用的自定义优化器参数.默认为 {}. |
{}
|