
MelGAN-based spectrogram inversion using feature matching

Author: Darshan Deshpande
Date created: 02/09/2021
Last modified: 15/09/2021


Description: Inversion of audio from mel spectrograms using the MelGAN architecture and feature matching.


Introduction

Autoregressive vocoders have been ubiquitous for a majority of the history of speech processing, but for most of their existence they have lacked parallelism. MelGAN is a non-autoregressive, fully convolutional vocoder architecture used for purposes ranging from spectral inversion and speech enhancement to present-day state-of-the-art speech synthesis, when used as a decoder with models like Tacotron2 or FastSpeech that convert text to mel spectrograms.

In this tutorial, we will have a look at the MelGAN architecture and how it achieves fast spectral inversion, i.e. conversion of spectrograms to audio waves. The MelGAN implemented in this tutorial is similar to the original implementation, the only difference being the padding method for the convolutions: we will use "same" padding instead of reflect padding.
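As a minimal illustration of that difference (a sketch, not part of the original example), "same" padding zero-pads the sequence edges inside the convolution, whereas reflection padding mirrors the edge samples explicitly before a "valid" convolution:

import tensorflow as tf

# A toy signal: batch of 1, length 5, 1 channel
x = tf.reshape(tf.range(5, dtype=tf.float32), (1, 5, 1))

# "same" padding (used in this tutorial) zero-pads internally
same_out = tf.keras.layers.Conv1D(1, 3, padding="same")(x)

# Reflection padding (used in the original MelGAN) mirrors the edges,
# followed by a "valid" convolution
reflected = tf.pad(x, [[0, 0], [1, 1], [0, 0]], mode="REFLECT")
reflect_out = tf.keras.layers.Conv1D(1, 3, padding="valid")(reflected)

print(same_out.shape, reflect_out.shape)  # both (1, 5, 1)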


Importing and defining hyperparameters

!pip install -qqq tensorflow_addons
!pip install -qqq tensorflow-io
import tensorflow as tf
import tensorflow_io as tfio
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow_addons import layers as addon_layers

# Setting logger level to avoid input shape warnings
tf.get_logger().setLevel("ERROR")

# Defining hyperparameters

DESIRED_SAMPLES = 8192
LEARNING_RATE_GEN = 1e-5
LEARNING_RATE_DISC = 1e-6
BATCH_SIZE = 16

mse = keras.losses.MeanSquaredError()
mae = keras.losses.MeanAbsoluteError()

Loading the dataset

This example uses the LJSpeech dataset.

The LJSpeech dataset is primarily used for text-to-speech and consists of 13,100 discrete speech samples taken from 7 non-fiction books, having a total length of approximately 24 hours. The MelGAN training is only concerned with the audio waves, so we process only the WAV files and ignore the audio annotations.

!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
!tar -xf /content/LJSpeech-1.1.tar.bz2
--2021-09-16 11:45:24--  https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
Resolving data.keithito.com (data.keithito.com)... 174.138.79.61
Connecting to data.keithito.com (data.keithito.com)|174.138.79.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2748572632 (2.6G) [application/octet-stream]
Saving to: ‘LJSpeech-1.1.tar.bz2’
LJSpeech-1.1.tar.bz 100%[===================>]   2.56G  68.3MB/s    in 36s     
2021-09-16 11:46:01 (72.2 MB/s) - ‘LJSpeech-1.1.tar.bz2’ saved [2748572632/2748572632]

We create a tf.data.Dataset to load and process the audio files. The preprocess() function takes a file path as input and returns two instances of the wave: one as input and one as the ground truth for comparison. The input wave will be mapped to a spectrogram using the custom MelSpec layer, as shown later in this example.

# Splitting the dataset into training and testing splits
wavs = tf.io.gfile.glob("LJSpeech-1.1/wavs/*.wav")
print(f"音频文件数量: {len(wavs)}")

# Mapper function for loading the audio. This function returns two instances of the wave
def preprocess(filename):
    audio = tf.audio.decode_wav(tf.io.read_file(filename), 1, DESIRED_SAMPLES).audio
    return audio, audio

# Create the tf.data.Dataset objects and apply preprocessing
train_dataset = tf.data.Dataset.from_tensor_slices((wavs,))
train_dataset = train_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
Number of audio files: 13100
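As a quick check (hypothetical, not part of the original example), taking one element confirms that both returned waves have shape (DESIRED_SAMPLES, 1):

for x, y in train_dataset.take(1):
    print(x.shape, y.shape)  # (8192, 1) (8192, 1)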

Defining custom layers for MelGAN

The MelGAN architecture consists of 3 main modules:

  1. The residual block
  2. Dilated convolutional block
  3. Discriminator block

(Figure: MelGAN architecture)

Since the network takes a mel spectrogram as input, we will create an additional custom layer which can convert the raw audio wave to a spectrogram on the fly. We use the raw audio tensor from train_dataset and map it to a mel spectrogram using the MelSpec layer below.

# Custom keras layer for on-the-fly audio to spectrogram conversion


class MelSpec(layers.Layer):
    def __init__(
        self,
        frame_length=1024,
        frame_step=256,
        fft_length=None,
        sampling_rate=22050,
        num_mel_channels=80,
        freq_min=125,
        freq_max=7600,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        self.sampling_rate = sampling_rate
        self.num_mel_channels = num_mel_channels
        self.freq_min = freq_min
        self.freq_max = freq_max
        # Defining mel filter. This filter will be multiplied with the STFT output
        self.mel_filterbank = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.num_mel_channels,
            num_spectrogram_bins=self.frame_length // 2 + 1,
            sample_rate=self.sampling_rate,
            lower_edge_hertz=self.freq_min,
            upper_edge_hertz=self.freq_max,
        )

    def call(self, audio, training=True):
        # We will only perform the transformation during training
        if training:
            # Taking the Short Time Fourier Transform. Ensure that the audio is padded.
            # In the paper, the STFT output is padded using the 'REFLECT' strategy.
            stft = tf.signal.stft(
                tf.squeeze(audio, -1),
                self.frame_length,
                self.frame_step,
                self.fft_length,
                pad_end=True,
            )

            # Taking the magnitude of the STFT output
            magnitude = tf.abs(stft)

            # Multiplying the Mel-filterbank with the magnitude and scaling it using the db scale
            mel = tf.matmul(tf.square(magnitude), self.mel_filterbank)
            log_mel_spec = tfio.audio.dbscale(mel, top_db=80)
            return log_mel_spec
        else:
            return audio

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "frame_length": self.frame_length,
                "frame_step": self.frame_step,
                "fft_length": self.fft_length,
                "sampling_rate": self.sampling_rate,
                "num_mel_channels": self.num_mel_channels,
                "freq_min": self.freq_min,
                "freq_max": self.freq_max,
            }
        )
        return config
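As a quick smoke test (hypothetical, not part of the original example): with frame_step=256 and pad_end=True, one second of audio at 22,050 Hz should map to ceil(22050 / 256) = 87 mel frames with 80 channels each:

dummy_audio = tf.random.uniform([1, 22050, 1])
print(MelSpec()(dummy_audio).shape)  # (1, 87, 80)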

The residual convolutional block extensively uses dilations and has a total receptive field of 27 timesteps per block. The dilations must grow as a power of the kernel_size to ensure a reduction of hissing noise in the output. The network proposed in the paper is as follows:

(Figure: ConvBlock)
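Counting only the dilated convolutions (kernel size k = 3, dilations 1, 3, and 9), the receptive field quoted above works out as:

RF = 1 + (k - 1) * (1 + 3 + 9) = 1 + 2 * 13 = 27 timesteps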

# Creating the residual stack block


def residual_stack(input, filters):
    """具有权重归一化的卷积残差堆叠。

    Args:
        filters: int,确定残差堆叠的滤波器大小。

    Returns:
        残差堆叠输出。
    """
    c1 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(input)
    lrelu1 = layers.LeakyReLU()(c1)
    c2 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu1)
    add1 = layers.Add()([c2, input])

    lrelu2 = layers.LeakyReLU()(add1)
    c3 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=3, padding="same"), data_init=False
    )(lrelu2)
    lrelu3 = layers.LeakyReLU()(c3)
    c4 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu3)
    add2 = layers.Add()([add1, c4])

    lrelu4 = layers.LeakyReLU()(add2)
    c5 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=9, padding="same"), data_init=False
    )(lrelu4)
    lrelu5 = layers.LeakyReLU()(c5)
    c6 = addon_layers.WeightNormalization(
        layers.Conv1D(filters, 3, dilation_rate=1, padding="same"), data_init=False
    )(lrelu5)
    add3 = layers.Add()([c6, add2])

    return add3
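Note that the Add skip connections require the channel count of input to match filters. A quick eager-mode shape check (hypothetical, not part of the original example) confirms that the stack preserves its input shape:

print(residual_stack(tf.zeros([1, 100, 64]), 64).shape)  # (1, 100, 64)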

Each convolutional block uses the dilations offered by the residual stack and upsamples the input data by the upsampling_factor. Note that in the generator below, the four upsampling factors multiply to 8 * 8 * 2 * 2 = 256, which matches the frame_step of the MelSpec layer, so each mel frame is decoded back into 256 audio samples.

# Dilated convolutional block consisting of the residual stack


def conv_block(input, conv_dim, upsampling_factor):
    """带有权重归一化的扩张卷积块。

    Args:
        conv_dim: int, 确定块的滤波器大小。
        upsampling_factor: int, 上采样的缩放因子。

    Returns:
        扩张卷积块。
    """
    conv_t = addon_layers.WeightNormalization(
        layers.Conv1DTranspose(conv_dim, 16, upsampling_factor, padding="same"),
        data_init=False,
    )(input)
    lrelu1 = layers.LeakyReLU()(conv_t)
    res_stack = residual_stack(lrelu1, conv_dim)
    lrelu2 = layers.LeakyReLU()(res_stack)
    return lrelu2
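A quick shape check (hypothetical, not part of the original example) shows the transposed convolution stretching the time axis by upsampling_factor:

print(conv_block(tf.zeros([1, 10, 512]), 256, 8).shape)  # (1, 80, 256)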

The discriminator block consists of convolutions and downsampling layers. This block is essential for the implementation of the feature matching technique.

Each discriminator outputs a list of feature maps that will be compared during training to compute the feature matching loss.

def discriminator_block(input):
    conv1 = addon_layers.WeightNormalization(
        layers.Conv1D(16, 15, 1, "same"), data_init=False
    )(input)
    lrelu1 = layers.LeakyReLU()(conv1)
    conv2 = addon_layers.WeightNormalization(
        layers.Conv1D(64, 41, 4, "same", groups=4), data_init=False
    )(lrelu1)
    lrelu2 = layers.LeakyReLU()(conv2)
    conv3 = addon_layers.WeightNormalization(
        layers.Conv1D(256, 41, 4, "same", groups=16), data_init=False
    )(lrelu2)
    lrelu3 = layers.LeakyReLU()(conv3)
    conv4 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 41, 4, "same", groups=64), data_init=False
    )(lrelu3)
    lrelu4 = layers.LeakyReLU()(conv4)
    conv5 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 41, 4, "same", groups=256), data_init=False
    )(lrelu4)
    lrelu5 = layers.LeakyReLU()(conv5)
    conv6 = addon_layers.WeightNormalization(
        layers.Conv1D(1024, 5, 1, "same"), data_init=False
    )(lrelu5)
    lrelu6 = layers.LeakyReLU()(conv6)
    conv7 = addon_layers.WeightNormalization(
        layers.Conv1D(1, 3, 1, "same"), data_init=False
    )(lrelu6)
    return [lrelu1, lrelu2, lrelu3, lrelu4, lrelu5, lrelu6, conv7]
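A quick shape check (hypothetical, not part of the original example): the four stride-4 convolutions downsample an 8192-sample wave by a factor of 4**4 = 256, and the block returns 7 feature maps with the final prediction map last:

feature_maps = discriminator_block(tf.zeros([1, 8192, 1]))
print(len(feature_maps))       # 7
print(feature_maps[-1].shape)  # (1, 32, 1)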

Creating the generator

def create_generator(input_shape):
    inp = keras.Input(input_shape)
    x = MelSpec()(inp)
    x = layers.Conv1D(512, 7, padding="same")(x)
    x = layers.LeakyReLU()(x)
    x = conv_block(x, 256, 8)
    x = conv_block(x, 128, 8)
    x = conv_block(x, 64, 2)
    x = conv_block(x, 32, 2)
    x = addon_layers.WeightNormalization(
        layers.Conv1D(1, 7, padding="same", activation="tanh")
    )(x)
    return keras.Model(inp, x)


# We use a dynamic input shape for the generator since the model is fully convolutional
generator = create_generator((None, 1))
generator.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
mel_spec (MelSpec)              (None, None, 80)     0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, None, 512)    287232      mel_spec[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, None, 512)    0           conv1d[0][0]                     
__________________________________________________________________________________________________
weight_normalization (WeightNor (None, None, 256)    2097921     leaky_re_lu[0][0]                
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, None, 256)    0           weight_normalization[0][0]       
__________________________________________________________________________________________________
weight_normalization_1 (WeightN (None, None, 256)    197121      leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, None, 256)    0           weight_normalization_1[0][0]     
__________________________________________________________________________________________________
weight_normalization_2 (WeightN (None, None, 256)    197121      leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
add (Add)                       (None, None, 256)    0           weight_normalization_2[0][0]     
                                                                 leaky_re_lu_1[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, None, 256)    0           add[0][0]                        
__________________________________________________________________________________________________
weight_normalization_3 (WeightN (None, None, 256)    197121      leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, None, 256)    0           weight_normalization_3[0][0]     
__________________________________________________________________________________________________
weight_normalization_4 (WeightN (None, None, 256)    197121      leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, None, 256)    0           add[0][0]                        
                                                                 weight_normalization_4[0][0]     
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, None, 256)    0           add_1[0][0]                      
__________________________________________________________________________________________________
weight_normalization_5 (WeightN (None, None, 256)    197121      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, None, 256)    0           weight_normalization_5[0][0]     
__________________________________________________________________________________________________
weight_normalization_6 (WeightN (None, None, 256)    197121      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
add_2 (Add)                     (None, None, 256)    0           weight_normalization_6[0][0]     
                                                                 add_1[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, None, 256)    0           add_2[0][0]                      
__________________________________________________________________________________________________
weight_normalization_7 (WeightN (None, None, 128)    524673      leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, None, 128)    0           weight_normalization_7[0][0]     
__________________________________________________________________________________________________
weight_normalization_8 (WeightN (None, None, 128)    49409       leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, None, 128)    0           weight_normalization_8[0][0]     
__________________________________________________________________________________________________
weight_normalization_9 (WeightN (None, None, 128)    49409       leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
add_3 (Add)                     (None, None, 128)    0           weight_normalization_9[0][0]     
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, None, 128)    0           add_3[0][0]                      
__________________________________________________________________________________________________
weight_normalization_10 (Weight (None, None, 128)    49409       leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, None, 128)    0           weight_normalization_10[0][0]    
__________________________________________________________________________________________________
weight_normalization_11 (Weight (None, None, 128)    49409       leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
add_4 (Add)                     (None, None, 128)    0           add_3[0][0]                      
                                                                 weight_normalization_11[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU)      (None, None, 128)    0           add_4[0][0]                      
__________________________________________________________________________________________________
weight_normalization_12 (Weight (None, None, 128)    49409       leaky_re_lu_12[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU)      (None, None, 128)    0           weight_normalization_12[0][0]    
__________________________________________________________________________________________________
weight_normalization_13 (Weight (None, None, 128)    49409       leaky_re_lu_13[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, None, 128)    0           weight_normalization_13[0][0]    
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU)      (None, None, 128)    0           add_5[0][0]                      
__________________________________________________________________________________________________
weight_normalization_14 (Weight (None, None, 64)     131265      leaky_re_lu_14[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU)      (None, None, 64)     0           weight_normalization_14[0][0]    
__________________________________________________________________________________________________
weight_normalization_15 (Weight (None, None, 64)     12417       leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU)      (None, None, 64)     0           weight_normalization_15[0][0]    
__________________________________________________________________________________________________
weight_normalization_16 (Weight (None, None, 64)     12417       leaky_re_lu_16[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, None, 64)     0           weight_normalization_16[0][0]    
                                                                 leaky_re_lu_15[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU)      (None, None, 64)     0           add_6[0][0]                      
__________________________________________________________________________________________________
weight_normalization_17 (Weight (None, None, 64)     12417       leaky_re_lu_17[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_18 (LeakyReLU)      (None, None, 64)     0           weight_normalization_17[0][0]    
__________________________________________________________________________________________________
weight_normalization_18 (Weight (None, None, 64)     12417       leaky_re_lu_18[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, None, 64)     0           add_6[0][0]                      
                                                                 weight_normalization_18[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_19 (LeakyReLU)      (None, None, 64)     0           add_7[0][0]                      
__________________________________________________________________________________________________
weight_normalization_19 (Weight (None, None, 64)     12417       leaky_re_lu_19[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_20 (LeakyReLU)      (None, None, 64)     0           weight_normalization_19[0][0]    
__________________________________________________________________________________________________
weight_normalization_20 (Weight (None, None, 64)     12417       leaky_re_lu_20[0][0]             
__________________________________________________________________________________________________
add_8 (Add)                     (None, None, 64)     0           weight_normalization_20[0][0]    
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
leaky_re_lu_21 (LeakyReLU)      (None, None, 64)     0           add_8[0][0]                      
__________________________________________________________________________________________________
weight_normalization_21 (Weight (None, None, 32)     32865       leaky_re_lu_21[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_22 (LeakyReLU)      (None, None, 32)     0           weight_normalization_21[0][0]    
__________________________________________________________________________________________________
weight_normalization_22 (Weight (None, None, 32)     3137        leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_23 (LeakyReLU)      (None, None, 32)     0           weight_normalization_22[0][0]    
__________________________________________________________________________________________________
weight_normalization_23 (Weight (None, None, 32)     3137        leaky_re_lu_23[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, None, 32)     0           weight_normalization_23[0][0]    
                                                                 leaky_re_lu_22[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_24 (LeakyReLU)      (None, None, 32)     0           add_9[0][0]                      
__________________________________________________________________________________________________
weight_normalization_24 (Weight (None, None, 32)     3137        leaky_re_lu_24[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_25 (LeakyReLU)      (None, None, 32)     0           weight_normalization_24[0][0]    
__________________________________________________________________________________________________
weight_normalization_25 (Weight (None, None, 32)     3137        leaky_re_lu_25[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, None, 32)     0           add_9[0][0]                      
                                                                 weight_normalization_25[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_26 (LeakyReLU)      (None, None, 32)     0           add_10[0][0]                     
__________________________________________________________________________________________________
weight_normalization_26 (Weight (None, None, 32)     3137        leaky_re_lu_26[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_27 (LeakyReLU)      (None, None, 32)     0           weight_normalization_26[0][0]    
__________________________________________________________________________________________________
weight_normalization_27 (Weight (None, None, 32)     3137        leaky_re_lu_27[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, None, 32)     0           weight_normalization_27[0][0]    
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
leaky_re_lu_28 (LeakyReLU)      (None, None, 32)     0           add_11[0][0]                     
__________________________________________________________________________________________________
weight_normalization_28 (Weight (None, None, 1)      452         leaky_re_lu_28[0][0]             
==================================================================================================
Total params: 4,646,912
Trainable params: 4,646,658
Non-trainable params: 254
__________________________________________________________________________________________________

Creating the discriminator

def create_discriminator(input_shape):
    inp = keras.Input(input_shape)
    out_map1 = discriminator_block(inp)
    pool1 = layers.AveragePooling1D()(inp)
    out_map2 = discriminator_block(pool1)
    pool2 = layers.AveragePooling1D()(pool1)
    out_map3 = discriminator_block(pool2)
    return keras.Model(inp, [out_map1, out_map2, out_map3])


# We use a dynamic input shape for the discriminator
# This is done because the input shape of the generated audio is unknown
discriminator = create_discriminator((None, 1))

discriminator.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, None, 1)]    0                                            
__________________________________________________________________________________________________
average_pooling1d (AveragePooli (None, None, 1)      0           input_2[0][0]                    
__________________________________________________________________________________________________
average_pooling1d_1 (AveragePoo (None, None, 1)      0           average_pooling1d[0][0]          
__________________________________________________________________________________________________
weight_normalization_29 (Weight (None, None, 16)     273         input_2[0][0]                    
__________________________________________________________________________________________________
weight_normalization_36 (Weight (None, None, 16)     273         average_pooling1d[0][0]          
__________________________________________________________________________________________________
weight_normalization_43 (Weight (None, None, 16)     273         average_pooling1d_1[0][0]        
__________________________________________________________________________________________________
leaky_re_lu_29 (LeakyReLU)      (None, None, 16)     0           weight_normalization_29[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_35 (LeakyReLU)      (None, None, 16)     0           weight_normalization_36[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_41 (LeakyReLU)      (None, None, 16)     0           weight_normalization_43[0][0]    
__________________________________________________________________________________________________
weight_normalization_30 (Weight (None, None, 64)     10625       leaky_re_lu_29[0][0]             
__________________________________________________________________________________________________
weight_normalization_37 (Weight (None, None, 64)     10625       leaky_re_lu_35[0][0]             
__________________________________________________________________________________________________
weight_normalization_44 (Weight (None, None, 64)     10625       leaky_re_lu_41[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_30 (LeakyReLU)      (None, None, 64)     0           weight_normalization_30[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_36 (LeakyReLU)      (None, None, 64)     0           weight_normalization_37[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_42 (LeakyReLU)      (None, None, 64)     0           weight_normalization_44[0][0]    
__________________________________________________________________________________________________
weight_normalization_31 (Weight (None, None, 256)    42497       leaky_re_lu_30[0][0]             
__________________________________________________________________________________________________
weight_normalization_38 (Weight (None, None, 256)    42497       leaky_re_lu_36[0][0]             
__________________________________________________________________________________________________
weight_normalization_45 (Weight (None, None, 256)    42497       leaky_re_lu_42[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_31 (LeakyReLU)      (None, None, 256)    0           weight_normalization_31[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_37 (LeakyReLU)      (None, None, 256)    0           weight_normalization_38[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_43 (LeakyReLU)      (None, None, 256)    0           weight_normalization_45[0][0]    
__________________________________________________________________________________________________
weight_normalization_32 (Weight (None, None, 1024)   169985      leaky_re_lu_31[0][0]             
__________________________________________________________________________________________________
weight_normalization_39 (Weight (None, None, 1024)   169985      leaky_re_lu_37[0][0]             
__________________________________________________________________________________________________
weight_normalization_46 (Weight (None, None, 1024)   169985      leaky_re_lu_43[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_32 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_32[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_38 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_39[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_44 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_46[0][0]    
__________________________________________________________________________________________________
weight_normalization_33 (Weight (None, None, 1024)   169985      leaky_re_lu_32[0][0]             
__________________________________________________________________________________________________
weight_normalization_40 (Weight (None, None, 1024)   169985      leaky_re_lu_38[0][0]             
__________________________________________________________________________________________________
weight_normalization_47 (Weight (None, None, 1024)   169985      leaky_re_lu_44[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_33 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_33[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_39 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_40[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_45 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_47[0][0]    
__________________________________________________________________________________________________
weight_normalization_34 (Weight (None, None, 1024)   5244929     leaky_re_lu_33[0][0]             
__________________________________________________________________________________________________
weight_normalization_41 (Weight (None, None, 1024)   5244929     leaky_re_lu_39[0][0]             
__________________________________________________________________________________________________
weight_normalization_48 (Weight (None, None, 1024)   5244929     leaky_re_lu_45[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_34 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_34[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_40 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_41[0][0]    
__________________________________________________________________________________________________
leaky_re_lu_46 (LeakyReLU)      (None, None, 1024)   0           weight_normalization_48[0][0]    
__________________________________________________________________________________________________
weight_normalization_35 (Weight (None, None, 1)      3075        leaky_re_lu_34[0][0]             
__________________________________________________________________________________________________
weight_normalization_42 (Weight (None, None, 1)      3075        leaky_re_lu_40[0][0]             
__________________________________________________________________________________________________
weight_normalization_49 (Weight (None, None, 1)      3075        leaky_re_lu_46[0][0]             
==================================================================================================
Total params: 16,924,107
Trainable params: 16,924,086
Non-trainable params: 21
__________________________________________________________________________________________________

Defining the loss functions

Generator Loss

The generator architecture uses a combination of two losses:

  1. Mean Squared Error:

This is the standard MSE generator loss, calculated between ones and the outputs of the final layer of each of the N discriminators.

  2. Feature Matching Loss:

This loss involves extracting the outputs of every layer of the discriminator for both the generated and the real wave, and comparing each pair of layer outputs k using Mean Absolute Error.

Discriminator Loss

The discriminator uses the Mean Squared Error, comparing the real data predictions with ones and the generated predictions with zeros.
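In symbols (a reconstruction from the code below, using this least-squares formulation rather than the paper's hinge loss), with D_i denoting the i-th of the 3 discriminator scales, D_i^j its j-th intermediate feature map, x the real wave and x_hat the generated wave:

gen_loss  = mean_i  MSE(1, D_i(x_hat))
fm_loss   = mean_{i,j}  MAE(D_i^j(x), D_i^j(x_hat))
disc_loss = mean_i [ MSE(1, D_i(x)) + MSE(0, D_i(x_hat)) ]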

# Generator loss


def generator_loss(real_pred, fake_pred):
    """生成器的损失函数。

    参数:
        real_pred: 张量,通过判别器处理的真实波输出。
        fake_pred: 张量,通过判别器处理的生成器预测输出。

    返回:
        生成器的损失。
    """
    gen_loss = []
    for i in range(len(fake_pred)):
        gen_loss.append(mse(tf.ones_like(fake_pred[i][-1]), fake_pred[i][-1]))

    return tf.reduce_mean(gen_loss)


def feature_matching_loss(real_pred, fake_pred):
    """实现特征匹配损失。

    参数:
        real_pred: 张量,通过判别器处理的真实波输出。
        fake_pred: 张量,通过判别器处理的生成器预测输出。

    返回:
        特征匹配损失。
    """
    fm_loss = []
    for i in range(len(fake_pred)):
        for j in range(len(fake_pred[i]) - 1):
            fm_loss.append(mae(real_pred[i][j], fake_pred[i][j]))

    return tf.reduce_mean(fm_loss)


def discriminator_loss(real_pred, fake_pred):
    """实现判别器损失。

    参数:
        real_pred: 张量,通过判别器处理的真实波输出。
        fake_pred: 张量,通过判别器处理的生成器预测输出。

    返回:
        判别器损失。
    """
    real_loss, fake_loss = [], []
    for i in range(len(real_pred)):
        real_loss.append(mse(tf.ones_like(real_pred[i][-1]), real_pred[i][-1]))
        fake_loss.append(mse(tf.zeros_like(fake_pred[i][-1]), fake_pred[i][-1]))

    # Calculating the final discriminator loss after scaling
    disc_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)
    return disc_loss
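As a quick, hypothetical sanity check, identical feature maps should yield a feature matching loss of exactly zero:

dummy_maps = [[tf.ones([1, 4, 2])] * 3 for _ in range(3)]
print(float(feature_matching_loss(dummy_maps, dummy_maps)))  # 0.0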

Defining the MelGAN model for training. This subclass overrides the train_step() method to implement the training logic.

class MelGAN(keras.Model):
    def __init__(self, generator, discriminator, **kwargs):
        """MelGAN 训练类

        Args:
            generator: keras.Model, 生成器模型
            discriminator: keras.Model, 鉴别器模型
        """
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
        generator_loss,
        feature_matching_loss,
        discriminator_loss,
    ):
        """MelGAN 编译方法。

        Args:
            gen_optimizer: keras.optimizer, 用于训练的优化器
            disc_optimizer: keras.optimizer, 用于训练的优化器
            generator_loss: 可调用对象, 生成器的损失函数
            feature_matching_loss: 可调用对象, 特征匹配的损失函数
            discriminator_loss: 可调用对象, 鉴别器的损失函数
        """
        super().compile()

        # Optimizers
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer

        # Losses
        self.generator_loss = generator_loss
        self.feature_matching_loss = feature_matching_loss
        self.discriminator_loss = discriminator_loss

        # Trackers
        self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")

    def train_step(self, batch):
        x_batch_train, y_batch_train = batch

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generating the audio wave
            gen_audio_wave = self.generator(x_batch_train, training=True)

            # Generating the discriminator features for the real and generated waves
            real_pred = self.discriminator(y_batch_train)
            fake_pred = self.discriminator(gen_audio_wave)

            # Calculating the generator losses
            gen_loss = self.generator_loss(real_pred, fake_pred)
            fm_loss = self.feature_matching_loss(real_pred, fake_pred)

            # Final generator loss: the adversarial loss plus the feature
            # matching loss weighted by 10, as in the paper
            gen_fm_loss = gen_loss + 10 * fm_loss

            # Calculating the discriminator loss
            disc_loss = self.discriminator_loss(real_pred, fake_pred)

        # Calculating and applying the gradients for the generator and discriminator
        grads_gen = gen_tape.gradient(gen_fm_loss, self.generator.trainable_weights)
        grads_disc = disc_tape.gradient(disc_loss, self.discriminator.trainable_weights)
        self.gen_optimizer.apply_gradients(zip(grads_gen, self.generator.trainable_weights))
        self.disc_optimizer.apply_gradients(
            zip(grads_disc, self.discriminator.trainable_weights)
        )

        self.gen_loss_tracker.update_state(gen_fm_loss)
        self.disc_loss_tracker.update_state(disc_loss)

        return {
            "gen_loss": self.gen_loss_tracker.result(),
            "disc_loss": self.disc_loss_tracker.result(),
        }

Training

The paper suggests that training with dynamic shapes takes around 400,000 steps (~500 epochs). For this example, we will run only a single epoch (819 steps). Longer training (greater than 300 epochs) will almost certainly provide better results.

gen_optimizer = keras.optimizers.Adam(
    LEARNING_RATE_GEN, beta_1=0.5, beta_2=0.9, clipnorm=1
)
disc_optimizer = keras.optimizers.Adam(
    LEARNING_RATE_DISC, beta_1=0.5, beta_2=0.9, clipnorm=1
)

# Start training
generator = create_generator((None, 1))
discriminator = create_discriminator((None, 1))

mel_gan = MelGAN(generator, discriminator)
mel_gan.compile(
    gen_optimizer,
    disc_optimizer,
    generator_loss,
    feature_matching_loss,
    discriminator_loss,
)
mel_gan.fit(
    train_dataset.shuffle(200).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE), epochs=1
)
819/819 [==============================] - 641s 696ms/step - gen_loss: 0.9761 - disc_loss: 0.9350

<keras.callbacks.History at 0x7f8f702fe050>

Testing the model

The trained model can now be used for real-time text-to-speech translation tasks. To test how fast MelGAN inference can be, let us take a sample audio mel spectrogram and convert it. Note that the actual model pipeline will not include the MelSpec layer, and hence this layer will be disabled during inference. The inference input will be a mel spectrogram processed similarly to the MelSpec layer configuration.

To test this, we will create a uniformly randomly distributed tensor to simulate the behavior of the inference pipeline.

# Sampling a random tensor to mimic a batch of 128 spectrograms of shape [50, 80]
audio_sample = tf.random.uniform([128, 50, 80])

Timing the inference speed of a single sample. Running this, you can see that the average inference time per spectrogram ranges from 8 to 10 milliseconds on a K80 GPU, which is pretty fast.

pred = generator.predict(audio_sample, batch_size=32, verbose=1)
4/4 [==============================] - 5s 280ms/step
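As a hypothetical follow-up (not part of the original example), the first generated wave can be written to disk for playback. With 50 input mel frames and a total upsampling factor of 256, each prediction holds 50 * 256 = 12,800 samples:

# Write the first generated wave to a hypothetical output path
wav_contents = tf.audio.encode_wav(pred[0], sample_rate=22050)
tf.io.write_file("generated_sample.wav", wav_contents)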

Conclusion

MelGAN is a highly effective architecture for spectral inversion, achieving a Mean Opinion Score (MOS) of 3.61 that considerably outperforms the Griffin-Lim algorithm, which scores just 1.57. MelGAN also holds up against the state-of-the-art WaveGlow and WaveNet architectures on text-to-speech and speech enhancement tasks, evaluated on the LJSpeech and VCTK datasets [1].

This tutorial highlights:

  1. The advantages of using dilated convolutions that grow with the filter size
  2. Implementing a custom layer for on-the-fly conversion of audio waves to mel spectrograms
  3. The effectiveness of the feature matching loss function for training GAN generators.

Further reading

  1. The MelGAN paper (Kundan Kumar et al.) to understand the reasoning behind the architecture and training process
  2. For an in-depth understanding of the feature matching loss, you can refer to Improved Techniques for Training GANs (Tim Salimans et al.).

Example available on HuggingFace

Trained Model | Demo