Source code for mambular.base_models.mambular

import torch
import torch.nn as nn
from ..arch_utils.mamba_arch import Mamba
from ..arch_utils.mlp_utils import MLP
from ..arch_utils.normalization_layers import (
    RMSNorm,
    LayerNorm,
    LearnableLayerScaling,
    BatchNorm,
    InstanceNorm,
    GroupNorm,
)
from ..configs.mambular_config import DefaultMambularConfig
from .basemodel import BaseModel
from ..arch_utils.embedding_layer import EmbeddingLayer


class Mambular(BaseModel):
    """
    A PyTorch model for tabular tasks that uses the Mamba architecture together with configurable normalization layers.

    Parameters
    ----------
    cat_feature_info : dict
        Dictionary containing information about categorical features.
    num_feature_info : dict
        Dictionary containing information about numerical features.
    num_classes : int, optional
        Number of output classes (default is 1).
    config : DefaultMambularConfig, optional
        Configuration object containing default hyperparameters for the model
        (default is DefaultMambularConfig()).
    **kwargs : dict
        Additional keyword arguments passed to BaseModel.

    Attributes
    ----------
    lr : float
        Learning rate.
    lr_patience : int
        Patience for the learning rate scheduler.
    weight_decay : float
        Weight decay for the optimizer.
    lr_factor : float
        Factor by which the learning rate is reduced.
    pooling_method : str
        Method used to pool the sequence of feature embeddings.
    cat_feature_info : dict
        Dictionary containing information about categorical features.
    num_feature_info : dict
        Dictionary containing information about numerical features.
    embedding_layer : EmbeddingLayer
        Embedding layer for numerical and categorical features.
    mamba : Mamba
        Mamba architecture component.
    norm_f : nn.Module
        Normalization layer applied before the prediction head.
    tabular_head : MLP
        Multi-layer perceptron head for tabular data.
    """

    def __init__(
        self,
        cat_feature_info,
        num_feature_info,
        num_classes=1,
        config: DefaultMambularConfig = DefaultMambularConfig(),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

        self.lr = self.hparams.get("lr", config.lr)
        self.lr_patience = self.hparams.get("lr_patience", config.lr_patience)
        self.weight_decay = self.hparams.get("weight_decay", config.weight_decay)
        self.lr_factor = self.hparams.get("lr_factor", config.lr_factor)
        self.pooling_method = self.hparams.get("pooling_method", config.pooling_method)
        self.shuffle_embeddings = self.hparams.get(
            "shuffle_embeddings", config.shuffle_embeddings
        )
        self.cat_feature_info = cat_feature_info
        self.num_feature_info = num_feature_info

        self.mamba = Mamba(
            d_model=self.hparams.get("d_model", config.d_model),
            n_layers=self.hparams.get("n_layers", config.n_layers),
            expand_factor=self.hparams.get("expand_factor", config.expand_factor),
            bias=self.hparams.get("bias", config.bias),
            d_conv=self.hparams.get("d_conv", config.d_conv),
            conv_bias=self.hparams.get("conv_bias", config.conv_bias),
            dropout=self.hparams.get("dropout", config.dropout),
            dt_rank=self.hparams.get("dt_rank", config.dt_rank),
            d_state=self.hparams.get("d_state", config.d_state),
            dt_scale=self.hparams.get("dt_scale", config.dt_scale),
            dt_init=self.hparams.get("dt_init", config.dt_init),
            dt_max=self.hparams.get("dt_max", config.dt_max),
            dt_min=self.hparams.get("dt_min", config.dt_min),
            dt_init_floor=self.hparams.get("dt_init_floor", config.dt_init_floor),
            norm=globals()[self.hparams.get("norm", config.norm)],
            activation=self.hparams.get("activation", config.activation),
            bidirectional=self.hparams.get("bidirectional", config.bidirectional),
            use_learnable_interaction=self.hparams.get(
                "use_learnable_interaction", config.use_learnable_interaction
            ),
            AD_weight_decay=self.hparams.get("AD_weight_decay", config.AD_weight_decay),
            BC_layer_norm=self.hparams.get("BC_layer_norm", config.BC_layer_norm),
            layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
        )

        norm_layer = self.hparams.get("norm", config.norm)
        if norm_layer == "RMSNorm":
            self.norm_f = RMSNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "LayerNorm":
            self.norm_f = LayerNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "BatchNorm":
            self.norm_f = BatchNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "InstanceNorm":
            self.norm_f = InstanceNorm(
                self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
            )
        elif norm_layer == "GroupNorm":
            self.norm_f = GroupNorm(
                1,
                self.hparams.get("d_model", config.d_model),
                eps=config.layer_norm_eps,
            )
        elif norm_layer == "LearnableLayerScaling":
            self.norm_f = LearnableLayerScaling(
                self.hparams.get("d_model", config.d_model)
            )
        else:
            raise ValueError(f"Unsupported normalization layer: {norm_layer}")

        self.embedding_layer = EmbeddingLayer(
            num_feature_info=num_feature_info,
            cat_feature_info=cat_feature_info,
            d_model=self.hparams.get("d_model", config.d_model),
            embedding_activation=self.hparams.get(
                "embedding_activation", config.embedding_activation
            ),
            layer_norm_after_embedding=self.hparams.get(
                "layer_norm_after_embedding", config.layer_norm_after_embedding
            ),
            use_cls=False,
            cls_position=-1,
            cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
        )

        head_activation = self.hparams.get("head_activation", config.head_activation)

        self.tabular_head = MLP(
            self.hparams.get("d_model", config.d_model),
            hidden_units_list=self.hparams.get(
                "head_layer_sizes", config.head_layer_sizes
            ),
            dropout_rate=self.hparams.get("head_dropout", config.head_dropout),
            use_skip_layers=self.hparams.get(
                "head_skip_layers", config.head_skip_layers
            ),
            activation_fn=head_activation,
            use_batch_norm=self.hparams.get(
                "head_use_batch_norm", config.head_use_batch_norm
            ),
            n_output_units=num_classes,
        )

        if self.pooling_method == "cls":
            self.use_cls = True
        else:
            self.use_cls = self.hparams.get("use_cls", config.use_cls)

        if self.shuffle_embeddings:
            self.perm = torch.randperm(self.embedding_layer.seq_len)
    def forward(self, num_features, cat_features):
        """
        Defines the forward pass of the model.

        Parameters
        ----------
        num_features : Tensor
            Tensor containing the numerical features.
        cat_features : Tensor
            Tensor containing the categorical features.

        Returns
        -------
        Tensor
            The output predictions of the model.
        """
        x = self.embedding_layer(num_features, cat_features)

        if self.shuffle_embeddings:
            x = x[:, self.perm, :]

        x = self.mamba(x)

        if self.pooling_method == "avg":
            x = torch.mean(x, dim=1)
        elif self.pooling_method == "max":
            x, _ = torch.max(x, dim=1)
        elif self.pooling_method == "sum":
            x = torch.sum(x, dim=1)
        elif self.pooling_method == "cls_token":
            x = x[:, -1]
        elif self.pooling_method == "last":
            x = x[:, -1]
        else:
            raise ValueError(f"Invalid pooling method: {self.pooling_method}")

        x = self.norm_f(x)
        preds = self.tabular_head(x)

        return preds
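# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library source).
# The constructor and forward signatures are those defined above; the layout
# of the feature-info dicts and of the forward inputs is an assumption here --
# in practice both are produced by mambular's preprocessing pipeline, so check
# EmbeddingLayer for the exact schema your version expects.
#
# if __name__ == "__main__":
#     # Assumed metadata layout (hypothetical values): feature name -> info
#     # dict with an input dimension, plus a category count for categoricals.
#     num_feature_info = {"age": {"dimension": 1}, "income": {"dimension": 1}}
#     cat_feature_info = {"city": {"dimension": 1, "categories": 10}}
#
#     model = Mambular(
#         cat_feature_info=cat_feature_info,
#         num_feature_info=num_feature_info,
#         num_classes=1,                   # single regression-style output
#         config=DefaultMambularConfig(),  # library defaults; override fields as needed
#     )
#
#     batch = 8
#     # Assumed input layout: one tensor per feature, in dict order.
#     num_features = [torch.randn(batch, 1), torch.randn(batch, 1)]
#     cat_features = [torch.randint(0, 10, (batch, 1))]
#
#     preds = model(num_features, cat_features)  # -> shape (batch, num_classes)
# ---------------------------------------------------------------------------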