Skip to content

合成表格数据

已过时

请注意,此示例不适用于最新版本的 ydata-synthetic

请查看 ydata-sdk 以了解如何生成合成数据

使用 WGAN 生成表格合成数据:

现实世界中的领域通常由表格数据描述,即可以结构化并组织成表格格式的数据,其中特征/变量表示为列,而观测值对应于行。

WGAN 是 GAN 的一种变体,它利用 Wasserstein 距离来提高训练稳定性和生成更高质量的样本:

以下是如何使用 WGAN 和 Credit Card 数据集合成表格数据的示例:

#Install ydata-synthetic lib
# pip install ydata-synthetic
import sklearn.cluster as cluster
import pandas as pd
import numpy as np

from ydata_synthetic.utils.cache import cache_file
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.regular import RegularSynthesizer

# Resolve the credit card CSV through cache_file and load it, using the
# file's first column as the DataFrame index.
data_path = cache_file(
    'creditcard.csv',
    'https://datahub.io/machine-learning/creditcard/r/creditcard.csv',
)
data = pd.read_csv(data_path, index_col=[0])

# Every column except the 'Class' label is treated as numerical; 'Class'
# is the single categorical column.
num_cols = [col for col in data.columns if col != 'Class']
cat_cols = ['Class']

print(f'Dataset columns: {num_cols}')

# Fixed column ordering used for the processed frame ('Class' goes last).
sorted_cols = [
    'V14', 'V4', 'V10', 'V17', 'V12', 'V26', 'Amount', 'V21', 'V8', 'V11',
    'V7', 'V28', 'V19', 'V3', 'V22', 'V6', 'V20', 'V27', 'V16', 'V13',
    'V25', 'V24', 'V18', 'V2', 'V1', 'V5', 'V15', 'V9', 'V23', 'Class',
]
processed_data = data[sorted_cols].copy()

# Only the minority class (rows with Class == 1) is synthesized in this example.
train_data = processed_data[processed_data['Class'] == 1].copy()

print(
    f"Dataset info: Number of records - {train_data.shape[0]} "
    f"Number of variables - {train_data.shape[1]}"
)

# Split the minority-class records into two KMeans clusters (fixed seed).
labels = cluster.KMeans(n_clusters=2, random_state=0).fit_predict(train_data[num_cols])

# Report how many records ended up in each cluster.
cluster_ids, cluster_sizes = np.unique(labels, return_counts=True)
print(pd.DataFrame({'count': cluster_sizes}, index=cluster_ids))

# Copy of the training data whose 'Class' column holds the cluster labels.
fraud_w_classes = train_data.assign(Class=labels)

# GAN training
# Hyperparameters forwarded to ModelParameters below.
noise_dim = 32      # passed as noise_dim
dim = 128           # passed as layers_dim
batch_size = 128

# Values forwarded to TrainParameters below.
log_step = 100      # passed as sample_interval
epochs = 500+1

# Optimizer settings: passed as lr and betas=(beta_1, beta_2).
learning_rate = 5e-4
beta_1 = 0.5
beta_2 = 0.9

# NOTE: not referenced anywhere else in this script.
models_dir = '../cache'

model_parameters = ModelParameters(batch_size=batch_size,
                                   lr=learning_rate,
                                   betas=(beta_1, beta_2),
                                   noise_dim=noise_dim,
                                   layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

# NOTE: not referenced anywhere else in this script.
test_size = 492 # number of fraud cases

# Build a RegularSynthesizer that uses the 'wgan' model and fit it on the
# training data with the numerical / categorical column split.
synth = RegularSynthesizer(
    modelname='wgan',
    model_parameters=model_parameters,
    n_critic=10,
)
synth.fit(
    data=train_data,
    train_arguments=train_args,
    num_cols=num_cols,
    cat_cols=cat_cols,
)

# Persist the fitted synthesizer so it can be reloaded later for sampling.
synth.save(path='creditcard_wgan_model.pkl')

# ---------------------------------------------------------
#  Reload the saved synthesizer and sample new records
# ---------------------------------------------------------
synth = RegularSynthesizer.load(path='creditcard_wgan_model.pkl')

# Request 100000 synthetic samples from the reloaded synthesizer.
data_sample = synth.sample(100_000)