# Source code for mambular.data_utils.dataset
import numpy as np
import torch
from torch.utils.data import Dataset
class MambularDataset(Dataset):
    """
    Custom dataset for handling structured data with separate categorical and numerical features, tailored for
    both regression and classification tasks.

    Parameters:
        cat_features_list (list of Tensors): A list of tensors representing the categorical features.
        num_features_list (list of Tensors): A list of tensors (or array-likes) representing the numerical
            features.
        labels (Tensor): A tensor (or array-like convertible with ``torch.as_tensor``) of labels.
        regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True.

    Attributes:
        num_classes (int): 1 for regression and for classification with at most two distinct label values;
            otherwise the number of distinct label values.
        labels (Tensor): The labels as a tensor; flattened to 1-D for multi-class classification.
    """

    def __init__(self, cat_features_list, num_features_list, labels, regression=True):
        self.cat_features_list = cat_features_list  # Categorical features tensors
        self.num_features_list = num_features_list  # Numerical features tensors
        self.regression = regression

        # Normalise labels to a tensor up front so NumPy arrays / lists are handled the same way as the
        # numerical features. For an existing tensor this is a no-op (no copy is made).
        labels = torch.as_tensor(labels)

        if self.regression:
            self.num_classes = 1
        else:
            # Count distinct classes on a detached CPU copy so tensors that live on another device or
            # require grad can be inspected as well.
            distinct = len(np.unique(labels.detach().cpu().numpy()))
            if distinct > 2:
                self.num_classes = distinct
                # reshape (unlike view) also works for non-contiguous label tensors.
                labels = labels.reshape(-1)
            else:
                # Binary classification is treated as a single-output problem.
                self.num_classes = 1

        self.labels = labels

    def __len__(self):
        """Return the number of samples, i.e. the length of the label tensor."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Retrieves the features and label for a given index.

        Parameters:
            idx (int): The index of the data point.

        Returns:
            tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical
            features, the latter cast to float32) and a single label. The label is float32 when ``regression``
            is True or ``num_classes`` is 1, and int64 (``torch.long``) otherwise.
        """
        cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list]
        num_features = [
            torch.as_tensor(feature_tensor[idx]).detach().clone().to(torch.float32)
            for feature_tensor in self.num_features_list
        ]

        if self.regression or self.num_classes == 1:
            label_dtype = torch.float32
        else:
            label_dtype = torch.long
        # detach().clone() hands back an independent copy, so callers cannot mutate the stored labels.
        label = torch.as_tensor(self.labels[idx]).detach().clone().to(label_dtype)

        # Keep categorical and numerical features separate
        return cat_features, num_features, label