合成时间序列数据
使用 DoppelGANger 生成合成时间序列数据:
尽管表格数据可能是最常讨论的数据类型,但现实世界中有许多领域——从交通和日常轨迹到股票价格和能源消耗模式——产生了时间序列数据,这为合成数据生成引入了多个复杂方面。
时间序列数据按顺序结构化,观测值根据其关联的时间戳或时间间隔按时间顺序排列。它明确包含了时间维度,允许分析趋势、季节性以及其他随时间变化的依赖关系。
DoppelGANger 是一种使用生成对抗网络(GAN)框架生成合成时间序列数据的模型,通过学习原始数据中的潜在时间依赖关系和特征:
以下是如何使用 DoppelGANger 和 Measuring Broadband America 数据集合成时间序列数据的示例:
"""
DoppelGANger architecture example file
"""
# Importing necessary libraries
import pandas as pd
from os import path
import matplotlib.pyplot as plt
from ydata_synthetic.synthesizers.timeseries import TimeSeriesSynthesizer
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
# Read the data
mba_data = pd.read_csv("../../data/fcc_mba.csv")
numerical_cols = ["traffic_byte_counter", "ping_loss_rate"]
categorical_cols = [col for col in mba_data.columns if col not in numerical_cols]
# Define model parameters
model_args = ModelParameters(batch_size=100,
lr=0.001,
betas=(0.2, 0.9),
latent_dim=20,
gp_lambda=2,
pac=1)
train_args = TrainParameters(epochs=400, sequence_length=56,
sample_length=8, rounds=1,
measurement_cols=["traffic_byte_counter", "ping_loss_rate"])
# Training the DoppelGANger synthesizer
if path.exists('doppelganger_mba'):
model_dop_gan = TimeSeriesSynthesizer.load('doppelganger_mba')
else:
model_dop_gan = TimeSeriesSynthesizer(modelname='doppelganger', model_parameters=model_args)
model_dop_gan.fit(mba_data, train_args, num_cols=numerical_cols, cat_cols=categorical_cols)
model_dop_gan.save('doppelganger_mba')
# Generate synthetic data
synth_data = model_dop_gan.sample(n_samples=600)
synth_df = pd.concat(synth_data, axis=0)
# Create a plot for each measurement column
plt.figure(figsize=(10, 6))
plt.subplot(2, 1, 1)
plt.plot(mba_data['traffic_byte_counter'].reset_index(drop=True), label='Real Traffic')
plt.plot(synth_df['traffic_byte_counter'].reset_index(drop=True), label='Synthetic Traffic', alpha=0.7)
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Traffic Comparison')
plt.legend()
plt.grid(True)
plt.subplot(2, 1, 2)
plt.plot(mba_data['ping_loss_rate'].reset_index(drop=True), label='Real Ping')
plt.plot(synth_df['ping_loss_rate'].reset_index(drop=True), label='Synthetic Ping', alpha=0.7)
plt.xlabel('Index')
plt.ylabel('Value')
plt.title('Ping Comparison')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()