GAN

! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on Colab
from __future__ import annotations
from fastai.basics import *
from fastai.vision.all import *


from nbdev.showdoc import *

Basic support for Generative Adversarial Networks

GAN stands for Generative Adversarial Network, invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to produce new images similar to the ones in the dataset, while the critic tries to distinguish real images from the ones the generator produces. The generator outputs images, and the critic outputs a single number (usually a probability: 1 for real images, 0 for fake ones).

We train them against each other, in the sense that at each step we (more or less):

  1. Freeze the generator and train the critic for one step on a batch of real images and a batch of generated ones.
  2. Freeze the critic and train the generator for one step, rewarding it when the critic mistakes its images for real ones.

(A plain-PyTorch sketch of one such round follows the note below.)
Note

The fastai library provides support for training GANs through the GANTrainer, but it doesn't include more than basic models.
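To make the alternation concrete, here is a minimal, hand-rolled sketch of one such round in plain PyTorch. Everything in it (generator, critic, real_batch, the two optimizers, the BCE-based losses) is an illustrative placeholder rather than fastai API; the rest of this page shows how GANModule, GANLoss, GANTrainer and the switchers package the same logic.

import torch
import torch.nn.functional as F

def train_gan_round(generator, critic, real_batch, opt_g, opt_c, noise_sz=100):
    "Illustrative only: one critic step followed by one generator step."
    bs = real_batch.shape[0]
    ones, zeros = real_batch.new_ones(bs, 1), real_batch.new_zeros(bs, 1)
    # 1. Freeze the generator (detach its output) and train the critic for one step:
    #    reward classifying real images as real and generated images as fake.
    fake = generator(torch.randn(bs, noise_sz, device=real_batch.device)).detach()
    crit_loss = (F.binary_cross_entropy_with_logits(critic(real_batch), ones)
                 + F.binary_cross_entropy_with_logits(critic(fake), zeros)) / 2
    opt_c.zero_grad(); crit_loss.backward(); opt_c.step()
    # 2. Freeze the critic and train the generator for one step:
    #    reward the generator when the critic believes its images are real.
    for p in critic.parameters(): p.requires_grad_(False)
    noise = torch.randn(bs, noise_sz, device=real_batch.device)
    gen_loss = F.binary_cross_entropy_with_logits(critic(generator(noise)), ones)
    opt_g.zero_grad(); gen_loss.backward(); opt_g.step()
    for p in critic.parameters(): p.requires_grad_(True)
    return crit_loss.item(), gen_loss.item()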

Wrapping the modules

class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self,
        generator:nn.Module=None, # Generator PyTorch module
        critic:nn.Module=None, # Critic PyTorch module
        gen_mode:None|bool=False # Whether the GAN should be set to generator mode
    ):
        if generator is not None: self.generator=generator
        if critic    is not None: self.critic   =critic
        store_attr('gen_mode')

    def forward(self, *args):
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self,
        gen_mode:None|bool=None # Whether the GAN should be set to generator mode
    ):
        "Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

This is just a shell containing the two models. When called, it delegates the input to either the generator or the critic depending on the value of gen_mode.

show_doc(GANModule.switch)

GANModule.switch[source]

GANModule.switch(gen_mode:(None, bool)=None)

Put the module in generator mode if gen_mode is True, in critic mode otherwise.

|  | Type | Default | Details |
|---|---|---|---|
| gen_mode | (None, bool) | None | Whether the GAN should be set to generator mode |

By default (leaving gen_mode as None), this puts the module in the other mode (critic mode if it was in generator mode, and vice versa).
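For instance, here is a tiny sketch using stand-in nn.Linear modules (purely to illustrate the switching behaviour; real generators and critics are built below):

gen, crit = nn.Linear(2, 4), nn.Linear(4, 1)  # hypothetical stand-ins for a generator and a critic
m = GANModule(gen, crit)        # gen_mode defaults to False, i.e. critic mode
m.switch()                      # no argument: toggle into generator mode
m.switch(gen_mode=False)        # explicit: force critic mode
test_eq(m.gen_mode, False)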

@delegates(ConvLayer.__init__)
def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the critic input
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    **kwargs
) -> nn.Sequential:
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]
    while cur_size > 4:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))
        cur_ftrs *= 2 ; cur_size //= 2
    init = kwargs.get('init', nn.init.kaiming_normal_)
    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]
    return nn.Sequential(*layers)
class AddChannels(Module):
    "Add `n_dim` channels at the end of the input."
    def __init__(self, n_dim): self.n_dim=n_dim
    def forward(self, x): return x.view(*(list(x.shape)+[1]*self.n_dim))
@delegates(ConvLayer.__init__)
def basic_generator(
    out_size:int, # Output size for the generator (same as the input size of the critic)
    n_channels:int, # Number of channels of the generator output
    in_sz:int=100, # Size of the input noise vector of the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    **kwargs
) -> nn.Sequential:
    "A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`."
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < out_size:  cur_size *= 2; cur_ftrs *= 2
    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]
    cur_size = 4
    while cur_size < out_size // 2:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))
        cur_ftrs //= 2; cur_size *= 2
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]
    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)
critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
_conv_args = dict(act_cls = partial(nn.LeakyReLU, negative_slope=0.2), norm_type=NormType.Spectral)

def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):
    if self_attention: kwargs['xtra'] = SelfAttention(nf)
    return ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
@delegates(ConvLayer)
def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    **kwargs
) -> SequentialEx:
    "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`."
    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        MergeLayer(dense=True))
def gan_critic(
    n_channels:int=3, # Number of channels of the critic input
    nf:int=128, # Number of features of the critic
    n_blocks:int=3, # Number of ResNet blocks in the critic
    p:float=0.15 # Amount of dropout in the critic
) -> nn.Sequential:
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        DenseResBlock(nf, **_conv_args)]
    nf *= 2 # after the dense block
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
        Flatten()]
    return nn.Sequential(*layers)
class GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self,
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        gan_model:GANModule # The GAN model
    ):
        super().__init__()
        store_attr('gen_loss_func,crit_loss_func,gan_model')

    def generator(self,
        output, # Generator outputs
        target # Real images
    ):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`"
        fake_pred = self.gan_model.critic(output)
        self.gen_loss = self.gen_loss_func(fake_pred, output, target)
        return self.gen_loss

    def critic(self,
        real_pred, # Critic predictions for real images
        input # Input noise vector to pass into the generator
    ):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        fake = self.gan_model.generator(input).requires_grad_(False)
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss
show_doc(GANLoss.generator)

GANLoss.generator[source]

GANLoss.generator(output, target)

Evaluate the output with the critic then uses self.gen_loss_func to evaluate how well the critic was fooled by output

|  | Type | Default | Details |
|---|---|---|---|
| output |  |  | Generator outputs |
| target |  |  | Real images |
show_doc(GANLoss.critic)

GANLoss.critic[source]

GANLoss.critic(real_pred, input)

Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.

|  | Type | Default | Details |
|---|---|---|---|
| real_pred |  |  | Critic predictions for real images |
| input |  |  | Input noise vector to pass into generator |

If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate whether the generator managed to fool the critic using gen_loss_func. That loss function has the following signature:

def gen_loss_func(fake_pred, output, target):

so that it can combine the output of the critic on output (the first argument, fake_pred) with output and target (if, for instance, you want to mix the GAN loss with other losses).

If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. That loss function has the following signature:

def crit_loss_func(real_pred, fake_pred):

where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise by the generator.
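For example, here is a minimal, hedged sketch of a pair of loss functions with these signatures, built on plain binary cross-entropy. The names example_gen_loss and example_crit_loss are illustrative, not part of the library; gan_loss_from_func further down this page builds something similar from a generator loss and a critic loss.

def example_gen_loss(fake_pred, output, target):
    "Reward the generator when the critic believes its images are real (illustrative only)."
    return F.binary_cross_entropy_with_logits(fake_pred, fake_pred.new_ones(fake_pred.shape))

def example_crit_loss(real_pred, fake_pred):
    "Reward the critic for separating real images from generated ones (illustrative only)."
    real_loss = F.binary_cross_entropy_with_logits(real_pred, real_pred.new_ones(real_pred.shape))
    fake_loss = F.binary_cross_entropy_with_logits(fake_pred, fake_pred.new_zeros(fake_pred.shape))
    return (real_loss + fake_loss) / 2

gan_loss = GANLoss(example_gen_loss, example_crit_loss, tst)  # `tst` is the GANModule created in the tests above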

class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit:callable): self.crit = crit
    def forward(self, output:Tensor, target:Tensor):
        return self.crit(output, target[:,None].expand_as(output).float())
def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True):
    "Compute thresholded accuracy after expanding `y_true` to the size of `y_pred`."
    if sigmoid: y_pred = y_pred.sigmoid()
    return ((y_pred>thresh).byte()==y_true[:,None].expand_as(y_pred).byte()).float().mean()

Callbacks for GAN training

def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool # `requires_grad` argument: `True` makes the parameters trainable (unfreezes the model), `False` freezes them
):
    for p in m.parameters(): p.requires_grad_(rg)
class GANTrainer(Callback):
    "Callback to handle GAN Training."
    run_after = TrainEvalCallback
    def __init__(self,
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating the loss
        clip:None|float=None, # How much to clip the weights
        beta:float=0.98, # Exponentially weighted smoothing factor `beta` for the losses
        gen_first:bool=False, # Whether to start with generator training
        show_img:bool=True, # Whether to show example generated images during training
    ):
        store_attr('switch_eval,clip,gen_first,show_img')
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)

    def _set_trainable(self):
        "Appropriately set the generator and critic into a trainable or loss evaluation mode based on `self.gen_mode`."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        set_freeze_model(train_model, True)
        set_freeze_model(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def before_fit(self):
        "Initialization."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()
        #self.recorder.no_val = True
        #self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        #self.imgs, self.titles = [], []

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
        if self.training and self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)

    def after_batch(self):
        "Record `last_loss` in the proper list."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            self.last_gen = self.learn.to_detach(self.pred)
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    #def after_epoch(self):
    #    "Show a sample image."
    #    if not hasattr(self, 'last_gen') or not self.show_img: return
    #    data = self.learn.data
    #    img = self.last_gen[0]
    #    norm = getattr(data,'norm',False)
    #    if norm and norm.keywords.get('do_y',False): img = data.denorm(img)
    #    img = data.train_ds.y.reconstruct(img)
    #    self.imgs.append(img)
    #    self.titles.append(f'Epoch {epoch}')
    #    pbar.show_imgs(self.imgs, self.titles)
    #    return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)
Warning

GANTrainer is useless on its own: you need to complete it with one of the following switchers.
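To give a sense of how the pieces fit together, here is a rough sketch of attaching GANTrainer together with the FixedGANSwitcher defined just below. It assumes dls, generator, critic, gen_loss_func and crit_loss_func already exist; this wiring is essentially what GANLearner (later on this page) does for you.

gan       = GANModule(generator, critic)
loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
cbs       = [GANTrainer(clip=0.01), FixedGANSwitcher(n_crit=5, n_gen=1)]
learn     = Learner(dls, gan, loss_func=loss_func, cbs=cbs,
                    metrics=LossMetrics('gen_loss,crit_loss'))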

class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self,
        n_crit:int=1, # How many steps of critic training before switching to the generator
        n_gen:int=1 # How many steps of generator training before switching to the critic
    ):
        store_attr('n_crit,n_gen')

    def before_train(self): self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0
class AdaptiveGANSwitcher(Callback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    run_after = GANTrainer
    def __init__(self,
        gen_thresh:None|float=None, # Loss threshold for the generator
        critic_thresh:None|float=None # Loss threshold for the critic
    ):
        store_attr('gen_thresh,critic_thresh')

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.gan_trainer.gen_mode:
            if self.gen_thresh is None or self.loss < self.gen_thresh: self.gan_trainer.switch()
        else:
            if self.critic_thresh is None or self.loss < self.critic_thresh: self.gan_trainer.switch()
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    run_after = GANTrainer
    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr

    def before_batch(self):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and self.training:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        if not self.learn.gan_trainer.gen_mode: self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)

GAN data

class InvisibleTensor(TensorBase):
    "TensorBase but show method does nothing"
    def show(self, ctx=None, **kwargs): return ctx
def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size=100 # Size of the returned noise vector
) -> InvisibleTensor:
    "Generate noise vector."
    return cast(torch.randn(size), InvisibleTensor)

We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
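A quick check of its behaviour (the file-name argument is ignored; only the size matters):

noise = generate_noise('unused_filename.jpg')  # the argument only exists for `DataBlock` compatibility
test_eq(noise.shape, [100])
test_eq(type(noise), InvisibleTensor)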

@typedispatch
def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs
@typedispatch
def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]
    return ctxs
bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop), 
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner

def gan_loss_from_func(
    loss_gen:callable, # Loss function for the generator; evaluates generator output images against target real images
    loss_crit:callable, # Loss function for the critic; evaluates predictions of real and fake images
    weights_gen:None|MutableSequence|tuple=None # Weights for the generator and critic loss functions
):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
def _tk_mean(fake_pred, output, target): return fake_pred.mean()
def _tk_diff(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()
@delegates()
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        gen_first:bool=False, # Whether to start with generator training
        switch_eval:bool=True, # Whether the model should be set to eval mode when calculating the loss
        show_img:bool=True, # Whether to show example generated images during training
        clip:None|float=None, # How much to clip the weights
        cbs:Callback|None|MutableSequence=None, # Additional callbacks
        metrics:None|MutableSequence|callable=None, # Metrics
        **kwargs
    ):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher()
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls,
        gen_learn:Learner, # A `Learner` object that has the generator
        crit_learn:Learner, # A `Learner` object that has the critic
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        weights_gen:None|MutableSequence|tuple=None, # Weights for the generator and critic loss functions
        **kwargs
    ):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
        clip:None|float=0.01, # How much to clip the weights
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating the loss
        **kwargs
    ):
        "Create a [WGAN](https://arxiv.org/abs/1701.07875) from `dls`, `generator` and `critic`."
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)
GANLearner.wgan = delegates(to=GANLearner.__init__)(GANLearner.wgan)
show_doc(GANLearner.from_learners)

GANLearner.from_learners[source]

GANLearner.from_learners(gen_learn:Learner, crit_learn:Learner, switcher:(Callback, None)=None, weights_gen:(None, list, tuple)=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:(None, float)=None, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a GAN from learn_gen and learn_crit.

|  | Type | Default | Details |
|---|---|---|---|
| gen_learn | Learner |  | A Learner object that has the generator |
| crit_learn | Learner |  | A Learner object that has the critic |
| switcher | (Callback, None) | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher |
| weights_gen | (None, list, tuple) | None | Weights for the generator and critic loss function |
| gen_first | bool | False | No Content |
| switch_eval | bool | True | No Content |
| show_img | bool | True | No Content |
| clip | (None, float) | None | No Content |
| cbs | (Callback, None, list) | None | No Content |
| metrics | (None, list, callable) | None | No Content |
| loss_func | NoneType | None | No Content |
| opt_func | function | <function Adam> | No Content |
| lr | float | 0.001 | No Content |
| splitter | function | <function trainable_params> | No Content |
| path | NoneType | None | No Content |
| model_dir | str | models | No Content |
| wd | NoneType | None | No Content |
| wd_bn_bias | bool | False | No Content |
| train_bn | bool | True | No Content |
| moms | tuple | (0.95, 0.85, 0.95) | No Content |
show_doc(GANLearner.wgan)

GANLearner.wgan[source]

GANLearner.wgan(dls:DataLoaders, generator:Module, critic:Module, switcher:(Callback, None)=None, clip:(None, float)=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a WGAN from dls, generator and critic.

|  | Type | Default | Details |
|---|---|---|---|
| dls | DataLoaders |  | DataLoaders object for GAN data |
| generator | Module |  | Generator model |
| critic | Module |  | Critic model |
| switcher | (Callback, None) | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher(n_crit=5, n_gen=1) |
| clip | (None, float) | 0.01 | How much to clip the weights |
| switch_eval | bool | False | Whether the model should be set to eval mode when calculating loss |
| gen_first | bool | False | No Content |
| show_img | bool | True | No Content |
| cbs | (Callback, None, list) | None | No Content |
| metrics | (None, list, callable) | None | No Content |
| loss_func | NoneType | None | No Content |
| opt_func | function | <function Adam> | No Content |
| lr | float | 0.001 | No Content |
| splitter | function | <function trainable_params> | No Content |
| path | NoneType | None | No Content |
| model_dir | str | models | No Content |
| wd | NoneType | None | No Content |
| wd_bn_bias | bool | False | No Content |
| train_bn | bool | True | No Content |
| moms | tuple | (0.95, 0.85, 0.95) | No Content |
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
| epoch | train_loss | gen_loss | crit_loss | time |
|-------|------------|----------|-----------|------|
| 0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
/home/tmabraham/anaconda3/envs/fastai/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
  warn("Your generator is empty.")
learn.show_results(max_n=9, ds_idx=0)

Export -

from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted index_original.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.