
Synthetic tabular data

Outdated

Note that this example does not work with the latest version of ydata-synthetic.

Please check ydata-sdk to see how to generate synthetic data.

Using WGAN-GP to generate tabular synthetic data:

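Real-world domains are often described by tabular data, i.e., data that can be structured and organized in a table-like format, where features/variables are represented as columns and observations correspond to rows.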

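WGAN-GP is a variant of the GAN that adds a gradient penalty term to the critic loss in order to improve training stability and the diversity of the generated samples.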
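To make the gradient penalty concrete, here is a minimal pure-Python sketch of the critic loss as it is usually written for WGAN-GP: E[D(fake)] - E[D(real)] + lam * E[(||grad D(x_hat)||_2 - 1)^2], where x_hat is a random interpolation between a real and a generated sample. This is an illustration only, not the ydata-synthetic implementation: the toy critic `tanh(w . x)`, the helper names, and the default `lam=10.0` are assumptions chosen so that the gradient can be computed by hand.

```python
import math
import random


def critic(x, w):
    """Toy critic D(x) = tanh(w . x); a hypothetical stand-in for a neural network."""
    return math.tanh(sum(wi * xi for wi, xi in zip(w, x)))


def critic_grad(x, w):
    """Analytic gradient of the toy critic with respect to its input x."""
    d = critic(x, w)
    return [(1.0 - d * d) * wi for wi in w]


def gradient_penalty(real, fake, w, lam=10.0, rng=random):
    """Mean of lam * (||grad D(x_hat)||_2 - 1)^2 over interpolated points x_hat."""
    total = 0.0
    for x_real, x_fake in zip(real, fake):
        eps = rng.random()  # mixing weight drawn uniformly from [0, 1)
        x_hat = [eps * r + (1.0 - eps) * f for r, f in zip(x_real, x_fake)]
        grad_norm = math.sqrt(sum(g * g for g in critic_grad(x_hat, w)))
        total += lam * (grad_norm - 1.0) ** 2
    return total / len(real)


def critic_loss(real, fake, w, lam=10.0, rng=random):
    """Critic loss: E[D(fake)] - E[D(real)] + gradient penalty."""
    mean_fake = sum(critic(x, w) for x in fake) / len(fake)
    mean_real = sum(critic(x, w) for x in real) / len(real)
    return mean_fake - mean_real + gradient_penalty(real, fake, w, lam, rng)
```

The penalty is zero when the critic's gradient has unit norm at the interpolated points and grows as the norm moves away from 1, which is how the term softly enforces the Lipschitz constraint instead of clipping weights.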

Here is an example of how to synthesize tabular data with WGAN-GP using the Adult Census Income dataset:

from pmlb import fetch_data

from ydata_synthetic.synthesizers.regular import RegularSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters

# Load the data and define the data processor parameters
data = fetch_data('adult')
num_cols = ['age', 'fnlwgt', 'capital-gain', 'capital-loss', 'hours-per-week']
cat_cols = ['workclass','education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex',
            'native-country', 'target']

# Define the model and training parameters
noise_dim = 128
dim = 128
batch_size = 500

log_step = 100
epochs = 500+1
learning_rate = [5e-4, 3e-3]
beta_1 = 0.5
beta_2 = 0.9
models_dir = '../cache'

gan_args = ModelParameters(batch_size=batch_size,
                           lr=learning_rate,
                           betas=(beta_1, beta_2),
                           noise_dim=noise_dim,
                           layers_dim=dim)

train_args = TrainParameters(epochs=epochs,
                             sample_interval=log_step)

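# Build a WGAN-GP synthesizer and fit it on the data
# (n_critic is the number of critic updates per generator update)
synth = RegularSynthesizer(modelname='wgangp', model_parameters=gan_args, n_critic=2)
synth.fit(data, train_args, num_cols, cat_cols)

# Save the trained synthesizer to disk
synth.save('adult_wgangp_model.pkl')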

#########################################################
#    Loading and sampling from a trained synthesizer    #
#########################################################
synth = RegularSynthesizer.load('adult_wgangp_model.pkl')
synth_data = synth.sample(1000)
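
Once samples are available, a quick sanity check is to compare how often each category occurs in a real column and in its synthetic counterpart. The sketch below is a hypothetical helper, not part of ydata-synthetic: it computes the total variation distance between the two sets of category shares, where 0 means the shares match exactly and 1 means the columns share no categories. It assumes both inputs are non-empty iterables of hashable values.

```python
from collections import Counter


def category_shares(values):
    """Return the relative frequency of each distinct value."""
    counts = Counter(values)
    total = sum(counts.values())
    return {key: count / total for key, count in counts.items()}


def total_variation_distance(real_values, synth_values):
    """Half the sum of absolute differences between the two share tables."""
    p = category_shares(real_values)
    q = category_shares(synth_values)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in set(p) | set(q))
```

For instance, `total_variation_distance(data['sex'], synth_data['sex'])` would compare one categorical column, assuming `data` and `synth_data` are both pandas DataFrames that contain that column.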