
Self-Supervised Models

Config Class

Bases: SSLModelConfig

Denoising AutoEncoder configuration.

Parameters:

Name Type Description Default
noise_strategy str

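Defines what kind of noise is introduced to the samples. swap - swap noise replaces the values of a feature with a random permutation of that same feature. zero - zero noise replaces the values of a feature with zeros. Defaults to swap. Choices: [swap, zero].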

'swap'
noise_probabilities Dict[str, float]

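Dict of per-feature probabilities for corrupting the input features with swap/zero noise. Keys should be feature names; any feature that is missing falls back to default_noise_probability. Defaults to an empty dict.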

{}
default_noise_probability float

Default probability for corrupting the input features with swap/zero noise, used for features that have no entry in noise_probabilities. Defaults to 0.8

0.8
loss_type_weights Optional[List[float]]

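Weights for the loss function, in the order [binary, categorical, numerical]. If None, default weights are computed with a formula; for example, the default weight for binary features is n_binary/n_features. Defaults to None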

None
mask_loss_weight float

Weight of the loss term for the masked features. Defaults to 2.0

2.0
max_onehot_cardinality int

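Maximum cardinality for one-hot encoding categorical features. Any categorical feature with cardinality > max_onehot_cardinality is embedded in a learned embedding space; the others are converted to a one-hot representation. If set to 0, the embedding strategy is used for all categorical features. Defaults to 4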

4
include_input_features_inference bool

If True, the input features are included along with the learned features during fine-tuning. Defaults to False

False
encoder_config Optional[ModelConfig]

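Configuration of the model's encoder. Should be one of the model configs defined in PyTorch Tabular; DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig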

None
decoder_config Optional[ModelConfig]

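Configuration of the model's decoder. Should be one of the model configs defined in PyTorch Tabular; DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig. If None (the default), nn.Identity is used as the decoder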

None
embedding_dims Optional[List]

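Embedding dimension for each categorical column, as a list of (cardinality, embedding dimension) tuples. If empty, it is inferred from the cardinality of each categorical column using the rule min(50, (x + 1) // 2)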

None
embedding_dropout float

Dropout rate applied to the categorical embeddings. Defaults to 0.1

0.1
batch_norm_continuous_input bool

If True, the continuous inputs are normalized with a BatchNorm layer.

True
learning_rate float

Learning rate of the model. Defaults to 1e-3

0.001
seed int

Seed for reproducibility. Defaults to 42

42
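To make the noise parameters concrete, here is a minimal pure-Python sketch of swap and zero noise with per-feature probabilities. It is an illustration under stated assumptions, not PyTorch Tabular's implementation: the library corrupts tensors inside its featurizer, whereas the helper `corrupt`, the list-of-dicts row layout and the use of `random.Random` are hypothetical choices made here.

```python
import random
from typing import Dict, List, Optional, Tuple

Row = Dict[str, float]


def corrupt(
    rows: List[Row],
    noise_strategy: str = "swap",
    noise_probabilities: Optional[Dict[str, float]] = None,
    default_noise_probability: float = 0.8,
    seed: int = 42,
) -> Tuple[List[Row], List[Dict[str, int]]]:
    """Hypothetical helper: return corrupted copies of `rows` and a 0/1 mask of corrupted cells."""
    if noise_strategy not in ("swap", "zero"):
        raise ValueError("noise_strategy must be 'swap' or 'zero'")
    noise_probabilities = noise_probabilities or {}
    rng = random.Random(seed)
    features = list(rows[0])
    corrupted = [dict(r) for r in rows]  # copies, so the input rows are left untouched
    mask = [{f: 0 for f in features} for _ in rows]
    for f in features:
        # Per-feature probability, falling back to default_noise_probability
        p = noise_probabilities.get(f, default_noise_probability)
        # Swap noise: replacement values come from a shuffled copy of the same column
        permuted = [r[f] for r in rows]
        rng.shuffle(permuted)
        for i in range(len(rows)):
            if rng.random() < p:
                corrupted[i][f] = permuted[i] if noise_strategy == "swap" else 0.0
                mask[i][f] = 1
    return corrupted, mask


rows = [{"age": 30.0, "income": 5.0}, {"age": 40.0, "income": 7.0}]
noisy, mask = corrupt(rows, "zero", noise_probabilities={"age": 0.0}, default_noise_probability=1.0)
print(noisy)  # [{'age': 30.0, 'income': 0.0}, {'age': 40.0, 'income': 0.0}]
print(mask)   # [{'age': 0, 'income': 1}, {'age': 0, 'income': 1}]
```

Swap noise keeps each column's marginal distribution (values are only reshuffled between rows), while zero noise does not. The returned mask corresponds to the target that the model's mask-reconstruction head learns to predict.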
Source code in src/pytorch_tabular/ssl_models/dae/config.py
@dataclass
class DenoisingAutoEncoderConfig(SSLModelConfig):
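    """Denoising AutoEncoder configuration.

    Parameters:
        noise_strategy (str): Defines what kind of noise is introduced to the samples. `swap` - swap noise replaces the values of a feature with a random permutation of that same feature. `zero` - zero noise replaces the values of a feature with zeros. Defaults to swap. Choices: [`swap`, `zero`].

        noise_probabilities (Dict[str, float]): Dict of per-feature probabilities for corrupting the input features with swap/zero noise. Keys should be feature names; any feature that is missing falls back to default_noise_probability. Defaults to an empty dict.

        default_noise_probability (float): Default probability for corrupting the input features with swap/zero noise, used for features that have no entry in noise_probabilities. Defaults to 0.8

        loss_type_weights (Optional[List[float]]): Weights for the loss function, in the order [binary, categorical, numerical]. If None, default weights are computed with a formula; for example, the default weight for binary features is n_binary/n_features. Defaults to None

        mask_loss_weight (float): Weight of the loss term for the masked features. Defaults to 2.0

        max_onehot_cardinality (int): Maximum cardinality for one-hot encoding categorical features. Any categorical feature with cardinality > max_onehot_cardinality is embedded in a learned embedding space; the others are converted to a one-hot representation. If set to 0, the embedding strategy is used for all categorical features. Defaults to 4

        include_input_features_inference (bool): If True, the input features are included along with the learned features during fine-tuning. Defaults to False

        encoder_config (Optional[pytorch_tabular.config.config.ModelConfig]): Configuration of the model's encoder. Should be one of the model configs defined in PyTorch Tabular; DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig

        decoder_config (Optional[pytorch_tabular.config.config.ModelConfig]): Configuration of the model's decoder. Should be one of the model configs defined in PyTorch Tabular; DenoisingAutoEncoderModel currently accepts only CategoryEmbeddingModelConfig. If None (the default), nn.Identity is used as the decoder

        embedding_dims (Optional[List]): Embedding dimension for each categorical column, as a list of (cardinality, embedding dimension) tuples. If empty, it is inferred from the cardinality of each categorical column using the rule min(50, (x + 1) // 2)

        embedding_dropout (float): Dropout rate applied to the categorical embeddings. Defaults to 0.1

        batch_norm_continuous_input (bool): If True, the continuous inputs are normalized with a BatchNorm layer.

        learning_rate (float): Learning rate of the model. Defaults to 1e-3

        seed (int): Seed for reproducibility. Defaults to 42"""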

    noise_strategy: str = field(
        default="swap",
        metadata={
            "help": "Defines what kind of noise we are introducing to samples."
            " `swap` - Swap noise is when we replace values of a feature with random permutations"
            " of the same feature. `zero` - Zero noise is when we replace values of a feature with zeros."
            " Defaults to swap",
            "choices": ["swap", "zero"],
        },
    )
    # Union not supported by omegaconf. Currently Union[float, Dict[str, float]]
    noise_probabilities: Dict[str, float] = field(
        default_factory=lambda: {},
        metadata={
            "help": "Dict of individual probabilities to corrupt the input features with swap/zero noise."
            " Key should be the feature name and if any feature is missing,"
            " the default_noise_probability is used. Default is an empty dict()"
        },
    )
    default_noise_probability: float = field(
        default=0.8,
        metadata={
            "help": "Default probability to corrupt the input features with swap/zero noise,"
            " used for features whose probability is not defined in noise_probabilities. Default is 0.8"
        },
    )
    loss_type_weights: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "Weights to be used for the loss function in the order [binary, categorical, numerical]."
            " If None, default weights are computed using a formula, e.g. for binary,"
            " default weight will be n_binary/n_features. Defaults to None"
        },
    )
    mask_loss_weight: float = field(
        default=2.0,
        metadata={"help": "Weight to be used for the loss function for the masked features. Defaults to 2.0"},
    )
    max_onehot_cardinality: int = field(
        default=4,
        metadata={
            "help": "Maximum cardinality of one-hot encoded categorical features."
            " Any categorical feature with cardinality>max_onehot_cardinality will be embedded"
            " in a learned embedding space and others will be converted to a one hot representation."
            " If set to 0, will use the embedding strategy for all categorical features. Default is 4"
        },
    )
    include_input_features_inference: bool = field(
        default=False,
        metadata={
            "help": "If True, will include the input features along with the learned features"
            " while fine tuning. Defaults to False"
        },
    )

    _module_src: str = field(default="ssl_models.dae")
    _model_name: str = field(default="DenoisingAutoEncoderModel")
    _config_name: str = field(default="DenoisingAutoEncoderConfig")

    def __post_init__(self):
        assert hasattr(self.encoder_config, "_backbone_name"), "encoder_config should have a _backbone_name attribute"
        if self.decoder_config is not None:
            assert hasattr(
                self.decoder_config, "_backbone_name"
            ), "decoder_config should have a _backbone_name attribute"
        super().__post_init__()
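The `embedding_dims` and `max_onehot_cardinality` settings together fix the width of the encoder input. The sketch below restates that arithmetic in plain Python, following the loop in `DenoisingAutoEncoderModel.__init__`; `default_embedding_dims` and `encoder_input_dim` are hypothetical helper names, not library API.

```python
from typing import List, Tuple


def default_embedding_dims(cardinalities: List[int]) -> List[Tuple[int, int]]:
    # Rule quoted in the embedding_dims description: min(50, (x + 1) // 2)
    return [(card, min(50, (card + 1) // 2)) for card in cardinalities]


def encoder_input_dim(
    embedding_dims: List[Tuple[int, int]],
    max_onehot_cardinality: int,
    n_continuous: int,
) -> int:
    """Encoded input width, mirroring the loop in DenoisingAutoEncoderModel.__init__."""
    width = 0
    for card, embd_dim in embedding_dims:
        if card == 2:
            width += 1  # binary feature: a single column
        elif card <= max_onehot_cardinality:
            width += card  # one-hot: one column per category
        else:
            width += embd_dim  # learned embedding
    return width + n_continuous


dims = default_embedding_dims([2, 3, 10])
print(dims)  # [(2, 1), (3, 2), (10, 5)]
print(encoder_input_dim(dims, max_onehot_cardinality=4, n_continuous=5))  # 14
print(encoder_input_dim(dims, max_onehot_cardinality=0, n_continuous=5))  # 13
```

Because the loop checks `card == 2` before consulting `max_onehot_cardinality`, a two-category feature still contributes a single binary column when `max_onehot_cardinality=0`.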

Model Class

Bases: SSLBaseModel

Source code in src/pytorch_tabular/ssl_models/dae/dae.py
class DenoisingAutoEncoderModel(SSLBaseModel):
    output_tuple = namedtuple("output_tuple", ["original", "reconstructed"])
    loss_weight_tuple = namedtuple("loss_weight_tuple", ["binary", "categorical", "continuous", "mask"])
    # fix for pickling
    # https://codefying.com/2019/05/04/dont-get-in-a-pickle-with-a-namedtuple/
    output_tuple.__qualname__ = "DenoisingAutoEncoderModel.output_tuple"
    loss_weight_tuple.__qualname__ = "DenoisingAutoEncoderModel.loss_weight_tuple"
    ALLOWED_MODELS = ["CategoryEmbeddingModelConfig"]

    def __init__(self, config: DictConfig, **kwargs):
        encoded_cat_dims = 0
        inferred_config = kwargs.get("inferred_config")
        for card, embd_dim in inferred_config.embedding_dims:
            if card == 2:
                encoded_cat_dims += 1
            elif card <= config.max_onehot_cardinality:
                encoded_cat_dims += card
            else:
                encoded_cat_dims += embd_dim
        config.encoder_config._backbone_input_dim = encoded_cat_dims + len(config.continuous_cols)
        assert config.encoder_config._config_name in self.ALLOWED_MODELS, (
            "Encoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
        )
        if config.decoder_config is not None:
            assert config.decoder_config._config_name in self.ALLOWED_MODELS, (
                "Decoder must be one of the following: " + ", ".join(self.ALLOWED_MODELS)
            )
            if "-" in config.encoder_config.layers:
                config.decoder_config._backbone_input_dim = int(config.encoder_config.layers.split("-")[-1])
            else:
                config.decoder_config._backbone_input_dim = int(config.encoder_config.layers)
        super().__init__(config, **kwargs)

    def _get_noise_probability(self, name):
        return self.hparams.noise_probabilities.get(name, self.hparams.default_noise_probability)

    @property
    def embedding_layer(self):
        return self._embedding

    @property
    def featurizer(self):
        return self._featurizer

    def _build_network(self):
        self._featurizer = DenoisingAutoEncoderFeaturizer(self.encoder, self.hparams)
        self._embedding = self._featurizer._build_embedding_layer()
        self.reconstruction = MultiTaskHead(
            self.decoder.output_dim,
            n_binary=len(self._embedding._binary_feat_idx),
            n_categorical=len(self._embedding._onehot_feat_idx),
            n_numerical=self._embedding.embedded_cat_dim + len(self.hparams.continuous_cols),
            cardinality=[self._embedding.categorical_embedding_dims[i][0] for i in self._embedding._onehot_feat_idx],
        )
        self.mask_reconstruction = nn.Linear(self.decoder.output_dim, len(self._featurizer.swap_noise.probas))

    def _setup_loss(self):
        self.losses = {
            "binary": nn.BCEWithLogitsLoss(),
            "categorical": nn.CrossEntropyLoss(),
            "continuous": nn.MSELoss(),
            "mask": nn.BCEWithLogitsLoss(),
        }
        if self.hparams.loss_type_weights is None:
            self.loss_weights = self.loss_weight_tuple(*self._init_loss_weights())
        else:
            self.loss_weights = self.loss_weight_tuple(*self.hparams.loss_type_weights, self.hparams.mask_loss_weight)

    def _init_loss_weights(self):
        n_features = self.hparams.continuous_dim + len(self.hparams.embedding_dims)
        return [
            len(self.embedding_layer._binary_feat_idx) / n_features,
            len(self.embedding_layer._onehot_feat_idx) / n_features,
            (self.hparams.continuous_dim + len(self.embedding_layer._embedding_feat_idx)) / n_features,
            self.hparams.mask_loss_weight,
        ]

    def _setup_metrics(self):
        return None

    def forward(self, x: Dict):
        if self.mode == "pretrain":
            x = self.embedding_layer(x)
            # (B, N, E)
            features = self.featurizer(x, perturb=True)
            z, mask = features.features, features.mask
            # decoder
            z_hat = self.decoder(z)
            # reconstruction
            reconstructed_in = self.reconstruction(z_hat)
            # mask reconstruction
            reconstructed_mask = self.mask_reconstruction(z_hat)
            output_dict = {"mask": self.output_tuple(mask, reconstructed_mask)}
            if "continuous" in reconstructed_in.keys():
                output_dict["continuous"] = self.output_tuple(
                    torch.cat(
                        [
                            i
                            for i in [
                                x.get("continuous", None),
                                x.get("embedding", None),
                            ]
                            if i is not None
                        ],
                        1,
                    ),
                    reconstructed_in["continuous"],
                )
            if "categorical" in reconstructed_in.keys():
                output_dict["categorical"] = self.output_tuple(x["_categorical_orig"], reconstructed_in["categorical"])
            if "binary" in reconstructed_in.keys():
                output_dict["binary"] = self.output_tuple(x["binary"], reconstructed_in["binary"])
            return output_dict
        else:  # self.mode == "finetune"
            z, x = self.featurizer(x, perturb=False, return_input=True)
            if self.hparams.include_input_features_inference:
                return torch.cat([z.features, x], 1)
            else:
                return z.features

    def calculate_loss(self, output, tag):
        total_loss = 0
        for type_, out in output.items():
            if type_ == "categorical":
                loss = 0
                for i in range(out.original.size(-1)):
                    loss += self.losses[type_](out.reconstructed[i], out.original[:, i])
            elif type_ == "binary":
                # Casting output to float for BCEWithLogitsLoss
                loss = self.losses[type_](out.reconstructed, out.original.float())
            else:
                loss = self.losses[type_](out.reconstructed, out.original)
            loss *= getattr(self.loss_weights, type_)
            self.log(
                f"{tag}_{type_}_loss",
                loss.item(),
                on_epoch=True,
                on_step=False,
                logger=True,
                prog_bar=False,
            )
            total_loss += loss
        self.log(
            f"{tag}_loss",
            total_loss,
            on_epoch=(tag == "valid") or (tag == "test"),
            on_step=(tag == "train"),
            # on_step=False,
            logger=True,
            prog_bar=True,
        )
        return total_loss

    def calculate_metrics(self, output, tag):
        pass

    def featurize(self, x: Dict):
        x = self.embedding_layer(x)
        return self.featurizer(x, perturb=False).features

    @property
    def output_dim(self):
        if self.mode == "finetune" and self.hparams.include_input_features_inference:
            return self._featurizer.encoder.output_dim + self.hparams.encoder_config._backbone_input_dim
        else:
            return self._featurizer.encoder.output_dim
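When `loss_type_weights` is None, `_init_loss_weights` derives one weight per reconstruction target from feature counts, and `calculate_loss` scales each per-type loss by its weight before summing. The sketch below reproduces that arithmetic with plain numbers. It assumes the continuous weight is `(continuous_dim + n_embedded) / n_features`, in line with the n_type/n_features formula given for `loss_type_weights`; the names `LossWeights`, `default_loss_weights` and `total_loss` are hypothetical.

```python
from collections import namedtuple
from typing import Dict

LossWeights = namedtuple("LossWeights", ["binary", "categorical", "continuous", "mask"])


def default_loss_weights(
    n_binary: int, n_onehot: int, n_embedded: int, n_continuous: int, mask_loss_weight: float = 2.0
) -> LossWeights:
    # n_features counts every input column before encoding: continuous + all categorical
    n_features = n_continuous + n_binary + n_onehot + n_embedded
    return LossWeights(
        binary=n_binary / n_features,
        categorical=n_onehot / n_features,
        # Embedded categoricals are reconstructed together with the continuous targets
        continuous=(n_continuous + n_embedded) / n_features,
        mask=mask_loss_weight,
    )


def total_loss(per_type_losses: Dict[str, float], weights: LossWeights) -> float:
    # Each per-type loss is multiplied by its weight, then everything is summed
    return sum(loss * getattr(weights, type_) for type_, loss in per_type_losses.items())


w = default_loss_weights(n_binary=1, n_onehot=1, n_embedded=2, n_continuous=4)
print(w)  # LossWeights(binary=0.125, categorical=0.125, continuous=0.75, mask=2.0)
print(total_loss({"binary": 1.0, "continuous": 2.0, "mask": 0.5}, w))  # 2.625
```

With this grouping the three feature-type weights sum to 1, while the mask term keeps its own independent scale set by `mask_loss_weight`.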

Base Model Class

Bases: LightningModule

Source code in src/pytorch_tabular/ssl_models/base_model.py
class SSLBaseModel(pl.LightningModule, metaclass=ABCMeta):
    def __init__(
        self,
        config: DictConfig,
        mode: str = "pretrain",
        encoder: Optional[nn.Module] = None,
        decoder: Optional[nn.Module] = None,
        custom_optimizer: Optional[torch.optim.Optimizer] = None,
        custom_optimizer_params: Dict = {},
        **kwargs,
    ):
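        """Base model for all SSL models.

        Parameters:
            config (DictConfig): User-defined configuration
            mode (str, optional): Mode of the model, "pretrain" or "finetune". Defaults to "pretrain".
            encoder (Optional[nn.Module], optional): Encoder module; if None, it is built from config.encoder_config. Defaults to None.
            decoder (Optional[nn.Module], optional): Decoder module; if None, it is built from config.decoder_config, or nn.Identity is used when that is also None. Defaults to None.
            custom_optimizer (Optional[torch.optim.Optimizer], optional): Custom optimizer class, instantiated with the model parameters, the configured learning_rate and custom_optimizer_params. Defaults to None.
            custom_optimizer_params (Dict, optional): Parameters for the custom optimizer. Defaults to {}.
            **kwargs: Must include `inferred_config`, which is merged into `config`.
        """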
        super().__init__()
        assert "inferred_config" in kwargs, "inferred_config not found in initialization arguments"
        inferred_config = kwargs["inferred_config"]
        # Merging the config and inferred config
        config = safe_merge_config(config, inferred_config)

        self._setup_encoder_decoder(
            encoder,
            config.encoder_config,
            decoder,
            config.decoder_config,
            inferred_config,
        )
        self.custom_optimizer = custom_optimizer
        self.custom_optimizer_params = custom_optimizer_params
        # Updating config with custom parameters for experiment tracking
        if self.custom_optimizer is not None:
            config.optimizer = str(self.custom_optimizer.__name__)
        if len(self.custom_optimizer_params) > 0:
            config.optimizer_params = self.custom_optimizer_params
        self.mode = mode
        self._check_and_verify()
        self.save_hyperparameters(config)
        self._build_network()
        self._setup_loss()
        self._setup_metrics()

    def _setup_encoder_decoder(self, encoder, encoder_config, decoder, decoder_config, inferred_config):
        assert (encoder is not None) or (
            encoder_config is not None
        ), "Either encoder or encoder_config must be provided"
        # assert (decoder is not None) or (decoder_config is not None),
        # "Either decoder or decoder_config must be provided"
        if encoder is not None:
            self.encoder = encoder
            self._custom_encoder = True
        else:
            # Since encoder is not provided, we will use the encoder_config
            model_callable = getattr_nested(encoder_config._module_src, encoder_config._backbone_name)
            self.encoder = model_callable(
                safe_merge_config(encoder_config, inferred_config),
                # inferred_config=inferred_config,
            )
        if decoder is not None:
            self.decoder = decoder
            self._custom_decoder = True
        elif decoder_config is not None:
            # Since decoder is not provided, we will use the decoder_config
            model_callable = getattr_nested(decoder_config._module_src, decoder_config._backbone_name)
            self.decoder = model_callable(
                safe_merge_config(decoder_config, inferred_config),
                # inferred_config=inferred_config,
            )
        else:
            self.decoder = nn.Identity()

    def _check_and_verify(self):
        assert hasattr(self.encoder, "output_dim"), "An encoder backbone must have an output_dim attribute"
        if isinstance(self.decoder, nn.Identity):
            self.decoder.output_dim = self.encoder.output_dim
        assert hasattr(self.decoder, "output_dim"), "A decoder must have an output_dim attribute"

    @property
    def embedding_layer(self):
        raise NotImplementedError("`embedding_layer` property needs to be implemented by inheriting classes")

    @property
    def featurizer(self):
        raise NotImplementedError("`featurizer` property needs to be implemented by inheriting classes")

    @abstractmethod
    def _setup_loss(self):
        pass

    @abstractmethod
    def _setup_metrics(self):
        pass

    @abstractmethod
    def calculate_loss(self, output, tag):
        pass

    @abstractmethod
    def calculate_metrics(self, output, tag):
        pass

    @abstractmethod
    def forward(self, x: Dict):
        pass

    @abstractmethod
    def featurize(self, x: Dict):
        pass

    def predict(self, x: Dict, ret_model_output: bool = True):  # ret_model_output only for compatibility
        assert ret_model_output, "ret_model_output must be True in case of SSL predict"
        return self.featurize(x)

    def data_aware_initialization(self, datamodule):
        pass

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.calculate_loss(output, tag="train")
        self.calculate_metrics(output, tag="train")
        return loss

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            output = self.forward(batch)
            self.calculate_loss(output, tag="valid")
            self.calculate_metrics(output, tag="valid")
        return output

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            output = self.forward(batch)
            self.calculate_loss(output, tag="test")
            self.calculate_metrics(output, tag="test")
        return output

    def on_validation_epoch_end(self) -> None:
        if hasattr(self.hparams, "log_logits") and self.hparams.log_logits:
            warnings.warn(
                "Logging Logits is disabled for SSL tasks. Set `log_logits` to False" " to turn off this warning"
            )
        super().on_validation_epoch_end()

    def configure_optimizers(self):
        if self.custom_optimizer is None:
            # Loading from the config
            try:
                self._optimizer = _create_optimizer(self.hparams.optimizer)
                opt = self._optimizer(
                    self.parameters(),
                    lr=self.hparams.learning_rate,
                    **self.hparams.optimizer_params,
                )
            except AttributeError as e:
                logger.error(f"{self.hparams.optimizer} is not a valid optimizer defined in the torch.optim module")
                raise e
        else:
            # Loading from custom fit arguments
            self._optimizer = self.custom_optimizer

            opt = self._optimizer(self.parameters(), lr=self.hparams.learning_rate, **self.custom_optimizer_params)
        if self.hparams.lr_scheduler is not None:
            try:
                self._lr_scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_scheduler)
            except AttributeError as e:
                logger.error(
                    f"{self.hparams.lr_scheduler} is not a valid learning rate scheduler defined"
                    f" in the torch.optim.lr_scheduler module"
                )
                raise e
            if isinstance(self._lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
                return {
                    "optimizer": opt,
                    "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
                }
            return {
                "optimizer": opt,
                "lr_scheduler": self._lr_scheduler(opt, **self.hparams.lr_scheduler_params),
                "monitor": self.hparams.lr_scheduler_monitor_metric,
            }
        else:
            return opt

    def reset_weights(self):
        reset_all_weights(self.featurizer)
        reset_all_weights(self.embedding_layer)
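The base class resolves its encoder and decoder in a fixed order: an explicitly passed module wins, otherwise the backbone is built from its config, and a missing decoder falls back to an identity whose `output_dim` is copied from the encoder. A framework-free sketch of that order follows; `Identity` and `Backbone` are stand-ins for `nn.Identity` and a real backbone, and the `build` callable replaces the `getattr_nested` lookup, so none of these names are library API.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple


@dataclass
class Identity:
    """Stand-in for nn.Identity: returns its input unchanged."""

    output_dim: Optional[int] = None

    def __call__(self, x: Any) -> Any:
        return x


@dataclass
class Backbone:
    """Hypothetical backbone; output_dim is the only attribute the checks rely on."""

    output_dim: int

    def __call__(self, x: Any) -> Any:
        return x


def resolve_encoder_decoder(
    build: Callable[[Dict], Backbone],
    encoder: Optional[Backbone] = None,
    encoder_config: Optional[Dict] = None,
    decoder: Optional[Backbone] = None,
    decoder_config: Optional[Dict] = None,
) -> Tuple[Backbone, Any]:
    if encoder is None and encoder_config is None:
        raise ValueError("Either encoder or encoder_config must be provided")
    # An explicitly passed module takes precedence over its config
    if encoder is None:
        encoder = build(encoder_config)
    if decoder is None:
        decoder = build(decoder_config) if decoder_config is not None else Identity()
    # Mirrors _check_and_verify: an identity decoder inherits the encoder's width
    if not hasattr(encoder, "output_dim"):
        raise AttributeError("An encoder backbone must have an output_dim attribute")
    if isinstance(decoder, Identity):
        decoder.output_dim = encoder.output_dim
    return encoder, decoder


def build(cfg: Dict) -> Backbone:
    return Backbone(output_dim=cfg["output_dim"])


enc, dec = resolve_encoder_decoder(build, encoder_config={"output_dim": 32})
print(type(dec).__name__, dec.output_dim)  # Identity 32
```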

__init__(config, mode='pretrain', encoder=None, decoder=None, custom_optimizer=None, custom_optimizer_params={}, **kwargs)

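Base model for all SSL models.

Parameters:

Name Type Description Default
config DictConfig

User-defined configuration

required
mode (str, optional)

Mode of the model, "pretrain" or "finetune". Defaults to "pretrain".

'pretrain'
encoder (Optional[Module], optional)

Encoder module; if None, it is built from config.encoder_config. Defaults to None.

None
decoder (Optional[Module], optional)

Decoder module; if None, it is built from config.decoder_config, or nn.Identity is used when that is also None. Defaults to None.

None
custom_optimizer (Optional[Optimizer], optional)

Custom optimizer class, instantiated with the model parameters, the configured learning_rate and custom_optimizer_params. Defaults to None.

None
custom_optimizer_params (Dict, optional)

Parameters for the custom optimizer. Defaults to {}.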

{}
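The `mode` argument decides what `forward` returns: reconstruction outputs in "pretrain" mode, learned features in "finetune" mode. For `DenoisingAutoEncoderModel`, the width of those features follows its `output_dim` property; the sketch below restates it with plain integers, where `ssl_feature_dim` is a hypothetical name and `encoded_input_dim` stands for `encoder_config._backbone_input_dim`.

```python
def ssl_feature_dim(
    mode: str,
    encoder_output_dim: int,
    encoded_input_dim: int,
    include_input_features_inference: bool,
) -> int:
    """Feature width, mirroring DenoisingAutoEncoderModel.output_dim."""
    if mode == "finetune" and include_input_features_inference:
        # Learned features are concatenated with the encoded input features
        return encoder_output_dim + encoded_input_dim
    # Otherwise only the encoder's learned features are counted
    return encoder_output_dim


print(ssl_feature_dim("finetune", 32, 14, True))   # 46
print(ssl_feature_dim("finetune", 32, 14, False))  # 32
print(ssl_feature_dim("pretrain", 32, 14, True))   # 32
```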