mambular.base_models.tabtransformer 源代码

import torch
import torch.nn as nn
from ..arch_utils.mlp_utils import MLP
from ..arch_utils.normalization_layers import (
    RMSNorm,
    LayerNorm,
    LearnableLayerScaling,
    BatchNorm,
    InstanceNorm,
    GroupNorm,
)
from ..arch_utils.embedding_layer import EmbeddingLayer
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
from .basemodel import BaseModel
from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer


[文档]class TabTransformer(BaseModel): """ A PyTorch model for tasks utilizing the Transformer architecture and various normalization techniques. Parameters ---------- cat_feature_info : dict Dictionary containing information about categorical features. num_feature_info : dict Dictionary containing information about numerical features. num_classes : int, optional Number of output classes (default is 1). config : DefaultFTTransformerConfig, optional Configuration object containing default hyperparameters for the model (default is DefaultMambularConfig()). **kwargs : dict Additional keyword arguments. Attributes ---------- lr : float Learning rate. lr_patience : int Patience for learning rate scheduler. weight_decay : float Weight decay for optimizer. lr_factor : float Factor by which the learning rate will be reduced. pooling_method : str Method to pool the features. cat_feature_info : dict Dictionary containing information about categorical features. num_feature_info : dict Dictionary containing information about numerical features. embedding_activation : callable Activation function for embeddings. encoder: callable stack of N encoder layers norm_f : nn.Module Normalization layer. num_embeddings : nn.ModuleList Module list for numerical feature embeddings. cat_embeddings : nn.ModuleList Module list for categorical feature embeddings. tabular_head : MLP Multi-layer perceptron head for tabular data. cls_token : nn.Parameter Class token parameter. embedding_norm : nn.Module, optional Layer normalization applied after embedding if specified. """ def __init__( self, cat_feature_info, num_feature_info, num_classes=1, config: DefaultTabTransformerConfig = DefaultTabTransformerConfig(), **kwargs, ): super().__init__(**kwargs) self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"]) if cat_feature_info == {}: raise ValueError( "You are trying to fit a TabTransformer with no categorical features. Try using a different model that is better suited for tasks without categorical features." ) layer_norm_dim = 0 for feature_name, input_shape in num_feature_info.items(): layer_norm_dim += input_shape self.lr = self.hparams.get("lr", config.lr) self.lr_patience = self.hparams.get("lr_patience", config.lr_patience) self.weight_decay = self.hparams.get("weight_decay", config.weight_decay) self.lr_factor = self.hparams.get("lr_factor", config.lr_factor) self.pooling_method = self.hparams.get("pooling_method", config.pooling_method) self.cat_feature_info = cat_feature_info self.num_feature_info = num_feature_info encoder_layer = CustomTransformerEncoderLayer( d_model=self.hparams.get("d_model", config.d_model), nhead=self.hparams.get("n_heads", config.n_heads), batch_first=True, dim_feedforward=self.hparams.get( "transformer_dim_feedforward", config.transformer_dim_feedforward ), dropout=self.hparams.get("attn_dropout", config.attn_dropout), activation=self.hparams.get( "transformer_activation", config.transformer_activation ), layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps), norm_first=self.hparams.get("norm_first", config.norm_first), bias=self.hparams.get("bias", config.bias), ) norm_layer = self.hparams.get("norm", config.norm) if norm_layer == "RMSNorm": self.norm_f = RMSNorm(layer_norm_dim) elif norm_layer == "LayerNorm": self.norm_f = LayerNorm(layer_norm_dim) elif norm_layer == "BatchNorm": self.norm_f = BatchNorm(layer_norm_dim) elif norm_layer == "InstanceNorm": self.norm_f = InstanceNorm(layer_norm_dim) elif norm_layer == "GroupNorm": self.norm_f = GroupNorm(1, layer_norm_dim) elif norm_layer == "LearnableLayerScaling": self.norm_f = LearnableLayerScaling(layer_norm_dim) else: self.norm_f = None self.norm_embedding = LayerNorm(self.hparams.get("d_model", config.d_model)) self.encoder = nn.TransformerEncoder( encoder_layer, num_layers=self.hparams.get("n_layers", config.n_layers), norm=self.norm_embedding, ) self.embedding_layer = EmbeddingLayer( num_feature_info=num_feature_info, cat_feature_info=cat_feature_info, d_model=self.hparams.get("d_model", config.d_model), embedding_activation=self.hparams.get( "embedding_activation", config.embedding_activation ), layer_norm_after_embedding=self.hparams.get( "layer_norm_after_embedding", config.layer_norm_after_embedding ), use_cls=True, cls_position=0, cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding), ) head_activation = self.hparams.get("head_activation", config.head_activation) mlp_input_dim = 0 for feature_name, input_shape in num_feature_info.items(): mlp_input_dim += input_shape mlp_input_dim += config.d_model self.tabular_head = MLP( self.hparams.get("d_model", mlp_input_dim), hidden_units_list=self.hparams.get( "head_layer_sizes", config.head_layer_sizes ), dropout_rate=self.hparams.get("head_dropout", config.head_dropout), use_skip_layers=self.hparams.get( "head_skip_layers", config.head_skip_layers ), activation_fn=head_activation, use_batch_norm=self.hparams.get( "head_use_batch_norm", config.head_use_batch_norm ), n_output_units=num_classes, ) self.cls_token = nn.Parameter( torch.zeros(1, 1, self.hparams.get("d_model", config.d_model)) )
[文档] def forward(self, num_features, cat_features): """ Defines the forward pass of the model. Parameters ---------- num_features : Tensor Tensor containing the numerical features. cat_features : Tensor Tensor containing the categorical features. Returns ------- Tensor The output predictions of the model. """ cat_embeddings = self.embedding_layer({}, cat_features) num_features = torch.cat(num_features, dim=1) num_embeddings = self.norm_f(num_features) x = self.encoder(cat_embeddings) if self.pooling_method == "avg": x = torch.mean(x, dim=1) elif self.pooling_method == "max": x, _ = torch.max(x, dim=1) elif self.pooling_method == "sum": x = torch.sum(x, dim=1) elif self.pooling_method == "cls": x = x[:, 0] else: raise ValueError(f"Invalid pooling method: {self.pooling_method}") x = torch.cat((x, num_embeddings), axis=1) preds = self.tabular_head(x) return preds