头部模型
配置类¶
线性头配置的模型类;作为模板和文档使用.模型接受字典作为输入,但如果存在本模型类中不存在的键,则会抛出异常.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
layers
|
str
|
分类/回归头中层数和单元数的连字符分隔字符串. 例如:32-64-32.默认情况下,仅从输入维度映射到输出维度. |
''
|
activation
|
str
|
分类头中的激活类型.默认激活类型类似于PyTorch中的ReLU、TanH、LeakyReLU等. 参考:https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity |
'ReLU'
|
dropout
|
float
|
分类元素被置零的概率. |
0.0
|
use_batch_norm
|
bool
|
标志,用于在每个线性层+DropOut后添加BatchNorm层. |
False
|
initialization
|
str
|
线性层的初始化方案.默认为 |
'kaiming'
|
Source code in src/pytorch_tabular/models/common/heads/config.py
混合密度网络头配置.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_gaussian
|
int
|
混合模型中高斯分布的数量.默认为1 |
1
|
sigma_bias_flag
|
bool
|
是否在sigma层中包含偏置项.默认为False |
False
|
mu_bias_init
|
Optional[List]
|
将mu层的偏置参数初始化为预定义的聚类中心.应为一个与混合模型中高斯数量相同长度的列表.强烈建议设置此参数以对抗模式崩溃.默认为None |
None
|
weight_regularization
|
Optional[int]
|
是否对MDN层应用L1或L2范数.默认为L2.可选值为: [ |
2
|
lambda_sigma
|
Optional[float]
|
sigma层权重正则化的正则化常数.默认为0.1 |
0.1
|
lambda_pi
|
Optional[float]
|
pi层权重正则化的正则化常数.默认为0.1 |
0.1
|
lambda_mu
|
Optional[float]
|
mu层权重正则化的正则化常数.默认为0 |
0
|
softmax_temperature
|
Optional[float]
|
用于混合系数gumbel softmax的温度.小于1的值会导致多个成分之间的过渡更尖锐.默认为1 |
1
|
n_samples
|
int
|
从后验分布中抽取样本以获得预测的数量.默认为100 |
100
|
central_tendency
|
str
|
用于获取点预测的度量方法.默认为均值.可选值为: [ |
'mean'
|
speedup_training
|
bool
|
开启此参数将取消训练期间的采样,从而加快训练速度,但也会使您无法查看训练指标.默认为False |
False
|
log_debug_plot
|
bool
|
开启此参数将绘制mu、sigma和pi层的直方图,以及logits(如果在实验配置中开启了log_logits).默认为False |
False
|
input_dim
|
int
|
输入到头部的维度.这将在从 |
None
|
Source code in src/pytorch_tabular/models/common/heads/config.py
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
|
头部类¶
Bases: Head
Source code in src/pytorch_tabular/models/common/heads/blocks.py
Bases: Module
Source code in src/pytorch_tabular/models/common/heads/blocks.py
gaussian_probability(sigma, mu, target, log=False)
¶
返回在给定高斯混合模型参数 sigma
和 mu
的条件下,target
的概率.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sigma
|
BxGxO
|
高斯分布的标准差.B 是批量大小,G 是高斯分布的数量,O 是每个高斯分布的维度数. |
required |
mu
|
BxGxO
|
高斯分布的均值.B 是批量大小,G 是高斯分布的数量,O 是每个高斯分布的维度数. |
required |
target
|
BxI
|
目标的批量.B 是批量大小,I 是输入维度数. |
required |
Returns: probabilities (BxG): 分布在相应 sigma/mu 索引中每个点的概率.
Source code in src/pytorch_tabular/models/common/heads/blocks.py
sample(pi, sigma, mu)
¶
从高斯混合模型 (MoG) 中抽取样本.