from __future__ import annotations
from fastai.basics import *
from import *

### 默认类级别 3


from nbdev.showdoc import *


GAN 代表 生成对抗网络,由 Ian Goodfellow 发明。其概念是我们同时训练两个模型:一个生成器和一个鉴别器。生成器会尝试生成与数据集中相似的新图像,而鉴别器则会尝试区分真实图像和生成器生成的图像。生成器输出图像,鉴别器输出一个数字(通常是一个概率,真实图像为 1,假图像为 0)。


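GAN 的训练循环由以下两个交替进行的步骤组成:

  1. 冻结生成器,训练鉴别器一步:
     - 取一批真实图像(记为 `real`),并用生成器生成一批假图像(记为 `fake`);
     - 让鉴别器分别评估这两批图像,并据此计算损失:该损失奖励把真实图像判为真、把假图像判为假;
     - 用该损失的梯度更新鉴别器的权重。
  1. 冻结鉴别器,训练生成器一步:
     - 用生成器生成一批假图像,并让鉴别器对其进行评估;
     - 计算损失:鉴别器越相信这些图像是真实的,损失越低;
     - 用该损失的梯度更新生成器的权重。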



class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self,
        generator:nn.Module=None, # The generator PyTorch module
        critic:nn.Module=None, # The critic (discriminator) PyTorch module
        gen_mode:None|bool=False # Whether the GAN should be set to generator mode
    ):
        # Only assign the sub-modules that were actually provided; passing `None`
        # leaves the attribute untouched.
        if generator is not None: self.generator=generator
        if critic    is not None: self.critic   =critic
        # `forward` and `switch` both read this flag, so it must always exist.
        self.gen_mode = gen_mode

    def forward(self, *args):
        "Run the generator when in generator mode, the critic otherwise."
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self,
        gen_mode:None|bool=None # Target mode; `None` toggles the current mode
    ):
        "Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode




GANModule.switch(gen_mode:(None, bool)=None)

Put the module in generator mode if gen_mode is True, in critic mode otherwise.

Type Default Details
gen_mode (None, bool) None Whether the GAN should be set to generator mode

默认情况下(gen_mode 留空,即为 None),调用该方法会让模块切换到另一种模式:如果当前处于生成器模式,则切换到判别器(critic)模式,反之亦然。

def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the input for the critic
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    **kwargs # Extra keyword arguments forwarded to every `ConvLayer` (e.g. `act_cls`, `init`)
) -> nn.Sequential:
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    # The first (stride-2) layer is deliberately built without normalization.
    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    # Optional extra layers keep both the tracked size and the feature count.
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]
    # Stride-2 layers halve the tracked size and double the features until it reaches 4.
    while cur_size > 4:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))
        cur_ftrs *= 2 ; cur_size //= 2
    init = kwargs.get('init', nn.init.kaiming_normal_)
    # A final 4x4 convolution down to a single channel, then flatten.
    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]
    return nn.Sequential(*layers)
class AddChannels(Module):
    "Add `n_dim` channels at the end of the input."
    def __init__(self, n_dim):
        # Number of trailing size-1 axes appended by `forward`.
        self.n_dim = n_dim

    def forward(self, x):
        # Keep every existing axis and append `n_dim` axes of size 1.
        new_shape = [*x.shape, *([1] * self.n_dim)]
        return x.view(*new_shape)
def basic_generator(
    out_size:int, # Output size for the generator (same as the input size for the critic)
    n_channels:int, # Number of channels of the output of the generator
    in_sz:int=100, # Size of the input noise vector for the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    **kwargs # Extra keyword arguments forwarded to every `ConvLayer`
) -> nn.Sequential:
    "A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`."
    # Size the first layer so that halving the features at each upsampling step
    # ends at `n_features` before the output layer (for power-of-two `out_size`).
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < out_size:  cur_size *= 2; cur_ftrs *= 2
    # `AddChannels(2)` appends two size-1 axes to the noise vector before the first transposed conv.
    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]
    cur_size = 4
    while cur_size < out_size // 2:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))
        cur_ftrs //= 2; cur_size *= 2
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]
    # Last stride-2 upsampling to `n_channels`; tanh bounds the output to [-1, 1].
    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)
# Smoke test: a 64x64 RGB critic/generator pair wrapped in a `GANModule`.
critic = basic_critic(in_size=64, n_channels=3)
generator = basic_generator(out_size=64, n_channels=3)
tst = GANModule(generator=generator, critic=critic)

# Critic mode (the `GANModule` default): one score per image.
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
# Shared conv settings for the GAN critic: leaky ReLU activations and spectral norm.
_conv_args = dict(
    act_cls=partial(nn.LeakyReLU, negative_slope=0.2),
    norm_type=NormType.Spectral,
)

def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):
    "`ConvLayer` with `_conv_args`, optionally followed by a `SelfAttention` on its `nf` outputs."
    if self_attention:
        # Overrides any `xtra` the caller may have passed, as before.
        kwargs = {**kwargs, 'xtra': SelfAttention(nf)}
    return ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    **kwargs # Extra keyword arguments forwarded to both `ConvLayer`s
) -> SequentialEx:
    "Resnet block of `nf` features. `kwargs` are passed to `ConvLayer`."
    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        # Dense merge of the block input with the conv output
                        # (callers account for the doubled feature count).
                        MergeLayer(dense=True))
def gan_critic(
    n_channels:int=3, # Number of channels of the input for the critic
    nf:int=128, # Number of features for the critic
    n_blocks:int=3, # Number of ResNet blocks within the critic
    p:float=0.15 # Amount of dropout in the critic
) -> nn.Sequential:
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        DenseResBlock(nf, **_conv_args)]
    nf *= 2 # the dense block concatenates its input, doubling the features
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            # Self-attention only on the first downsampling block.
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
        Flatten()]
    return nn.Sequential(*layers)
class GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self,
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        gan_model:GANModule # The GAN model
    ):
        super().__init__()
        self.gen_loss_func,self.crit_loss_func,self.gan_model = gen_loss_func,crit_loss_func,gan_model

    def generator(self,
        output, # Generator outputs
        target # Real images
    ):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`"
        fake_pred = self.gan_model.critic(output)
        self.gen_loss = self.gen_loss_func(fake_pred, output, target)
        return self.gen_loss

    def critic(self,
        real_pred, # Critic predictions for real images
        input # Input noise vector to pass into generator
    ):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        # Cut the generator out of the critic's graph. `detach` also works when the
        # generator's parameters still require grad; `requires_grad_(False)` would raise
        # on such a non-leaf tensor.
        fake = self.gan_model.generator(input).detach()
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss


GANLoss.generator(output, target)

Evaluate the output with the critic then uses self.gen_loss_func to evaluate how well the critic was fooled by output

Type Default Details
output Generator outputs
target Real images


GANLoss.critic(real_pred, input)

Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.

Type Default Details
real_pred Critic predictions for real images
input Input noise vector to pass into generator


def gen_loss_func(fake_pred, output, target):
    """Reference signature for the generator loss given to `GANLoss`.

    `fake_pred` is the critic's prediction on the generator `output`; `target` is the
    batch of real images, so the adversarial loss can be mixed with other losses.
    This is a signature reference only: supply your own function with this signature.
    """
    raise NotImplementedError("gen_loss_func is a signature reference; supply your own generator loss")



def crit_loss_func(real_pred, fake_pred):
    """Reference signature for the critic loss given to `GANLoss`.

    `real_pred` and `fake_pred` are the critic's predictions on real and generated
    images. This is a signature reference only: supply your own function with this signature.
    """
    raise NotImplementedError("crit_loss_func is a signature reference; supply your own critic loss")


class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit:callable):
        # Wrapped criterion, called as `crit(output, expanded_target)`.
        self.crit = crit

    def forward(self, output:Tensor, target:Tensor):
        # Insert an axis at position 1, broadcast to `output`'s shape and cast to float.
        expanded = target.unsqueeze(1).expand_as(output).float()
        return self.crit(output, expanded)
def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True):
    "Compute thresholded accuracy after expanding `y_true` to the size of `y_pred`."
    probs = y_pred.sigmoid() if sigmoid else y_pred
    # Broadcast the targets along a new axis 1 so they line up with every prediction.
    targets = y_true[:, None].expand_as(probs).byte()
    hits = (probs > thresh).byte() == targets
    return hits.float().mean()


def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool # `requires_grad` value for every parameter: `True` makes `m` trainable, `False` freezes it
):
    "Set `requires_grad` to `rg` on every parameter of `m`."
    for p in m.parameters(): p.requires_grad_(rg)
class GANTrainer(Callback):
    "Callback to handle GAN Training."
    run_after = TrainEvalCallback
    def __init__(self,
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        clip:None|float=None, # How much to clip the weights
        beta:float=0.98, # Exponentially weighted smoothing of the losses `beta`
        gen_first:bool=False, # Whether we start with generator training
        show_img:bool=True, # Whether to show example generated images during training
    ):
        self.switch_eval,self.clip,self.gen_first,self.show_img = switch_eval,clip,gen_first,show_img
        # One smoothed running loss for each side of the GAN.
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)

    def _set_trainable(self):
        "Appropriately set the generator and critic into a trainable or loss evaluation mode based on `self.gen_mode`."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        set_freeze_model(train_model, True)   # parameters require grad
        set_freeze_model(loss_model, False)   # frozen: only used to compute the loss
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def before_fit(self):
        "Grab both sub-models, apply the starting mode and reset the loss trackers."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        # Propagate the starting mode to the model, the loss and the trainable flags.
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
        if self.training and self.clip is not None:
            # Weight clipping of the critic (used by `GANLearner.wgan`).
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            # In critic mode the model consumes the real images, and `GANLoss.critic`
            # expects the noise as its second argument, so input and target are swapped.
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)

    def after_batch(self):
        "Record `last_loss` in the proper list."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            self.last_gen = self.learn.to_detach(self.pred)
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        # Pass the resolved mode (never `None`) so the model and the loss cannot drift
        # out of sync with this callback by toggling independently.
        self.model.switch(self.gen_mode)
        self.loss_func.switch(self.gen_mode)


class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self,
        n_crit:int=1, # How many steps of critic training before switching to generator (int or callable)
        n_gen:int=1 # How many steps of generator training before switching to critic (int or callable)
    ):
        self.n_crit,self.n_gen = n_crit,n_gen

    def before_train(self): self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        # A non-int `n_iter` is called with the other mode's counter to get the step count.
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0
class AdaptiveGANSwitcher(Callback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    run_after = GANTrainer
    def __init__(self,
        gen_thresh:None|float=None, # Loss threshold for generator; `None` switches after every step
        critic_thresh:None|float=None # Loss threshold for critic; `None` switches after every step
    ):
        self.gen_thresh,self.critic_thresh = gen_thresh,critic_thresh

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            if self.gen_thresh is None or self.learn.loss < self.gen_thresh: self.learn.gan_trainer.switch()
        else:
            if self.critic_thresh is None or self.learn.loss < self.critic_thresh: self.learn.gan_trainer.switch()
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    run_after = GANTrainer
    def __init__(self, mult_lr=5.):
        self.mult_lr = mult_lr
        self._lr_scaled = False # whether `before_batch` multiplied the lr for this batch

    def before_batch(self):
        "Multiply the current lr if necessary."
        # Remember the decision so `after_batch` undoes exactly what was done here,
        # even if a switcher callback changes the GAN mode in between.
        self._lr_scaled = self.training and not self.learn.gan_trainer.gen_mode
        if self._lr_scaled:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        if self._lr_scaled:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)
            self._lr_scaled = False

GAN 数据

class InvisibleTensor(TensorBase):
    "TensorBase but show method does nothing"
    def show(self, ctx=None, **kwargs):
        # Draw nothing; hand the (possibly `None`) context straight back.
        return ctx
def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size=100 # Size of returned noise vector
) -> InvisibleTensor:
    "Generate noise vector."
    # `fn` is ignored: each item gets a fresh `torch.randn` vector.
    noise = torch.randn(size)
    return cast(noise, InvisibleTensor)

我们使用 generate_noise 函数生成噪声向量,以传递给生成器进行图像生成。

@typedispatch
def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "`show_batch` for GAN data: the noise inputs are invisible, so only the images are drawn."
    # Registered as a typedispatch overload; without the decorator this would replace the
    # global dispatcher and `show_batch[object]` below would not be subscriptable.
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs
@typedispatch
def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "`show_results` for GAN data: draw the first element of each output on its own context."
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]
    return ctxs
bs = 128
size = 64
dblock = DataBlock(
    blocks=(TransformBlock, ImageBlock), # noise vector in, image out
    get_x=generate_noise,
    get_items=get_image_files,
    splitter=IndexSplitter([]), # empty validation set
    item_tfms=Resize(size, method=ResizeMethod.Crop),
    # mean 0.5 / std 0.5 per channel
    batch_tfms=Normalize.from_stats(torch.tensor([0.5]*3), torch.tensor([0.5]*3)),
)
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)

GAN 学习器

def gan_loss_from_func(
    loss_gen:callable, # 生成器的损失函数。评估生成器输出图像与目标真实图像。
    loss_crit:callable, # 用于评价判别器的损失函数。评估真实图像和生成图像的预测结果。
    weights_gen:None|MutableSequence|tuple=None # 生成器和判别器损失函数的权重
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
def _tk_mean(fake_pred, output, target):
    "WGAN generator loss: mean critic score of the fakes (`output` and `target` are unused)."
    return fake_pred.mean()

def _tk_diff(real_pred, fake_pred):
    "WGAN critic loss: mean critic score on real images minus mean score on fakes."
    return real_pred.mean() - fake_pred.mean()
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        gen_first:bool=False, # Whether we start with generator training
        switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
        show_img:bool=True, # Whether to show example generated images during training
        clip:None|float=None, # How much to clip the weights
        cbs:Callback|None|MutableSequence=None, # Additional callbacks
        metrics:None|MutableSequence|callable=None, # Metrics
        **kwargs # Passed on to `Learner.__init__`
    ):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher()
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        cbs = L(cbs) + L(trainer, switcher)
        # Always report the generator and critic losses alongside user metrics.
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls,
        gen_learn:Learner, # A `Learner` object that contains the generator
        crit_learn:Learner, # A `Learner` object that contains the critic
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        weights_gen:None|MutableSequence|tuple=None, # Weights for the generator and critic loss function
        **kwargs # Passed on to `GANLearner.__init__`
    ):
        "Create a GAN from `gen_learn` and `crit_learn`."
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
        clip:None|float=0.01, # How much to clip the weights
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        **kwargs # Passed on to `GANLearner.__init__`
    ):
        "Create a [WGAN](https://arxiv.org/abs/1701.07875) from `dls`, `generator` and `critic`."
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

# Both alternate constructors forward `**kwargs` to `GANLearner.__init__`; advertise
# those keyword arguments in their signatures.
_delegate_to_init = delegates(to=GANLearner.__init__)
GANLearner.from_learners = _delegate_to_init(GANLearner.from_learners)
GANLearner.wgan = _delegate_to_init(GANLearner.wgan)
del _delegate_to_init


GANLearner.from_learners(gen_learn:Learner, crit_learn:Learner, switcher:(Callback, None)=None, weights_gen:(None, list, tuple)=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:(None, float)=None, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a GAN from gen_learn and crit_learn.

Type Default Details
gen_learn Learner A Learner object that has the generator
crit_learn Learner A Learner object that has the critic
switcher (Callback, None) None Callback for switching between generator and critic training, defaults to FixedGANSwitcher
weights_gen (None, list, tuple) None Weights for the generator and critic loss function
gen_first bool False No Content
switch_eval bool True No Content
show_img bool True No Content
clip (None, float) None No Content
cbs (Callback, None, list) None No Content
metrics (None, list, callable) None No Content
loss_func NoneType None No Content
opt_func function <function Adam> No Content
lr float 0.001 No Content
splitter function <function trainable_params> No Content
path NoneType None No Content
model_dir str models No Content
wd NoneType None No Content
wd_bn_bias bool False No Content
train_bn bool True No Content
moms tuple (0.95, 0.85, 0.95) No Content


GANLearner.wgan(dls:DataLoaders, generator:Module, critic:Module, switcher:(Callback, None)=None, clip:(None, float)=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a WGAN from dls, generator and critic.

Type Default Details
dls DataLoaders DataLoaders object for GAN data
generator Module Generator model
critic Module Critic model
switcher (Callback, None) None Callback for switching between generator and critic training, defaults to FixedGANSwitcher(n_crit=5, n_gen=1)
clip (None, float) 0.01 How much to clip the weights
switch_eval bool False Whether the model should be set to eval mode when calculating loss
gen_first bool False No Content
show_img bool True No Content
cbs (Callback, None, list) None No Content
metrics (None, list, callable) None No Content
loss_func NoneType None No Content
opt_func function <function Adam> No Content
lr float 0.001 No Content
splitter function <function trainable_params> No Content
path NoneType None No Content
model_dir str models No Content
wd NoneType None No Content
wd_bn_bias bool False No Content
train_bn bool True No Content
moms tuple (0.95, 0.85, 0.95) No Content
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
# The DataBlock uses an empty validation split, so report metrics on the training set.
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
epoch train_loss gen_loss crit_loss time
0 -0.815071 0.646809 -1.140522 00:38
# Display up to 9 samples from the training set (ds_idx=0).
learn.show_results(ds_idx=0, max_n=9)

