import torch
import torch.nn as nn
from ..configs.mlp_config import DefaultMLPConfig
from .basemodel import BaseModel
from ..arch_utils.normalization_layers import (
    RMSNorm,
    LayerNorm,
    LearnableLayerScaling,
    BatchNorm,
    InstanceNorm,
    GroupNorm,
)
from ..arch_utils.embedding_layer import EmbeddingLayer
class MLP(BaseModel):
    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes: int = 1,
        config: DefaultMLPConfig = DefaultMLPConfig(),
        **kwargs,
    ):
"""
Initializes the MLP model with the given configuration.
Parameters
----------
cat_feature_info : Any
Information about categorical features.
num_feature_info : Any
Information about numerical features.
num_classes : int, optional
Number of output classes, by default 1.
config : DefaultMLPConfig, optional
Configuration dataclass containing hyperparameters, by default DefaultMLPConfig().
"""
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        self.lr = self.hparams.get("lr", config.lr)
        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
        self.cat_feature_info = cat_feature_info
        self.num_feature_info = num_feature_info

        # Initialize layers
        self.layers = nn.ModuleList()
        self.skip_connections = self.hparams.get(
            "skip_connections", config.skip_connections
        )
        self.use_glu = self.hparams.get("use_glu", config.use_glu)
        self.activation = self.hparams.get("activation", config.activation)
        self.use_embeddings = self.hparams.get("use_embeddings", config.use_embeddings)
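
        # Without embeddings, the raw input width is the sum of the numerical
        # feature widths plus one column per categorical feature (the categorical
        # inputs are presumably integer-encoded upstream).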
        input_dim = 0
        for feature_name, input_shape in num_feature_info.items():
            input_dim += input_shape
        for feature_name, input_shape in cat_feature_info.items():
            input_dim += 1
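
        # With embeddings enabled, every feature (numerical and categorical) is
        # mapped to a d_model-dimensional vector, and the flattened concatenation
        # of those vectors becomes the MLP input.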
        if self.use_embeddings:
            input_dim = (
                len(num_feature_info) * config.d_model
                + len(cat_feature_info) * config.d_model
            )
        # Input layer
        self.layers.append(nn.Linear(input_dim, config.layer_sizes[0]))
        if config.batch_norm:
            self.layers.append(nn.BatchNorm1d(config.layer_sizes[0]))
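
        # Map the `norm` setting to a normalization module; if one is selected,
        # it is appended to the stack right after the input block below.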
        norm_layer = self.hparams.get("norm", config.norm)
        if norm_layer == "RMSNorm":
            self.norm_f = RMSNorm(config.layer_sizes[0])
        elif norm_layer == "LayerNorm":
            self.norm_f = LayerNorm(config.layer_sizes[0])
        elif norm_layer == "BatchNorm":
            self.norm_f = BatchNorm(config.layer_sizes[0])
        elif norm_layer == "InstanceNorm":
            self.norm_f = InstanceNorm(config.layer_sizes[0])
        elif norm_layer == "GroupNorm":
            self.norm_f = GroupNorm(1, config.layer_sizes[0])
        elif norm_layer == "LearnableLayerScaling":
            self.norm_f = LearnableLayerScaling(config.layer_sizes[0])
        else:
            self.norm_f = None
        if self.norm_f is not None:
            # norm_f was already instantiated above, so append the module itself
            # rather than calling it with the layer size again.
            self.layers.append(self.norm_f)
        if config.use_glu:
            self.layers.append(nn.GLU())
        else:
            self.layers.append(self.activation)
        if config.dropout > 0.0:
            self.layers.append(nn.Dropout(config.dropout))
        # Hidden layers
        for i in range(1, len(config.layer_sizes)):
            self.layers.append(
                nn.Linear(config.layer_sizes[i - 1], config.layer_sizes[i])
            )
            if config.batch_norm:
                self.layers.append(nn.BatchNorm1d(config.layer_sizes[i]))
            if config.layer_norm:
                self.layers.append(nn.LayerNorm(config.layer_sizes[i]))
            if config.use_glu:
                self.layers.append(nn.GLU())
            else:
                self.layers.append(self.activation)
            if config.dropout > 0.0:
                self.layers.append(nn.Dropout(config.dropout))
        # Output layer
        self.layers.append(nn.Linear(config.layer_sizes[-1], num_classes))
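
        # Shared embedding layer: maps each numerical and categorical feature to a
        # d_model-dimensional vector; forward() flattens its
        # (batch, n_features, d_model) output before the first linear layer.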
        if self.use_embeddings:
            self.embedding_layer = EmbeddingLayer(
                num_feature_info=num_feature_info,
                cat_feature_info=cat_feature_info,
                d_model=self.hparams.get("d_model", config.d_model),
                embedding_activation=self.hparams.get(
                    "embedding_activation", config.embedding_activation
                ),
                layer_norm_after_embedding=self.hparams.get(
                    "layer_norm_after_embedding", config.layer_norm_after_embedding
                ),
                use_cls=False,
            )

    def forward(self, num_features, cat_features) -> torch.Tensor:
        """
        Forward pass of the MLP model.

        Parameters
        ----------
        num_features : list of torch.Tensor
            Numerical feature tensors.
        cat_features : list of torch.Tensor
            Categorical feature tensors.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        if self.use_embeddings:
            x = self.embedding_layer(num_features, cat_features)
            B, S, D = x.shape
            x = x.reshape(B, S * D)
        else:
            x = num_features + cat_features
            x = torch.cat(x, dim=1)
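
        # Apply all but the final layer; with skip connections enabled, a residual
        # connection is added around any linear layer whose input and output widths match.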
        for i in range(len(self.layers) - 1):
            if isinstance(self.layers[i], nn.Linear):
                out = self.layers[i](x)
                if self.skip_connections and x.shape == out.shape:
                    x = x + out
                else:
                    x = out
            else:
                x = self.layers[i](x)

        x = self.layers[-1](x)
        return x
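

# Minimal usage sketch (hypothetical names and shapes, not part of the module).
# It assumes the feature-info dicts map feature names to per-feature widths /
# cardinalities as expected by BaseModel and EmbeddingLayer, and that
# use_embeddings is disabled in the default config, so the raw tensors are
# concatenated directly:
#
#     num_feature_info = {"age": 1, "income": 1}        # numerical feature widths
#     cat_feature_info = {"city": 10}                    # e.g. number of categories
#     model = MLP(cat_feature_info, num_feature_info, num_classes=1)
#
#     num_features = [torch.randn(32, 1), torch.randn(32, 1)]
#     cat_features = [torch.zeros(32, 1)]                # integer-coded category ids
#     out = model(num_features, cat_features)            # -> shape (32, 1)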