! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on Colab
GAN
from __future__ import annotations
from fastai.basics import *
from fastai.vision.all import *
from nbdev.showdoc import *
Basic support for Generative Adversarial Networks
GAN stands for Generative Adversarial Network, and was invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in the dataset, while the critic tries to tell real images apart from the ones the generator produces. The generator outputs images; the critic outputs a single number (usually a probability: 1 for real images and 0 for fake ones).
We train them against each other, in the sense that at each step (more or less) we do the following (a minimal code sketch of this loop follows the list):
- Freeze the generator and train the critic for one step by:
  - getting one batch of real images (let's call them `real`)
  - generating one batch of fake images (let's call them `fake`)
  - having the critic evaluate each batch and compute a loss function from that; the important part is that it rewards the detection of real images positively and penalizes the fake ones
  - updating the weights of the critic with the gradients of this loss
- Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluating the critic on it
  - returning a loss that rewards the critic positively for thinking these are real images
  - updating the weights of the generator with the gradients of this loss
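The sketch below is a minimal, hand-rolled version of that alternating scheme in plain PyTorch, independent of the fastai classes defined on the rest of this page. The names (`gan_step`, `opt_g`, `opt_c`) and the simplified loss signatures are illustrative assumptions, not part of the library.

def gan_step(generator, critic, real, noise, gen_loss_func, crit_loss_func, opt_g, opt_c):
    # 1) Critic step: the generator is frozen (no gradients flow through it), the critic is updated.
    with torch.no_grad():
        fake = generator(noise)                         # batch of fake images
    crit_loss = crit_loss_func(critic(real), critic(fake))
    opt_c.zero_grad(); crit_loss.backward(); opt_c.step()

    # 2) Generator step: the critic is frozen, the generator is updated.
    for p in critic.parameters(): p.requires_grad_(False)
    gen_loss = gen_loss_func(critic(generator(noise)))  # rewards fooling the critic
    opt_g.zero_grad(); gen_loss.backward(); opt_g.step()
    for p in critic.parameters(): p.requires_grad_(True)
    return gen_loss, crit_loss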
The fastai library provides support for training GANs through the `GANTrainer`, but doesn't include more than basic models.
Wrapping the modules
class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self,
        generator:nn.Module=None, # The generator PyTorch module
        critic:nn.Module=None, # The discriminator (critic) PyTorch module
        gen_mode:None|bool=False # Whether the GAN should be set to generator mode
    ):
        if generator is not None: self.generator=generator
        if critic    is not None: self.critic  =critic
        store_attr('gen_mode')

    def forward(self, *args):
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self,
        gen_mode:None|bool=None # Whether the GAN should be set to generator mode
    ):
        "Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
This is just a shell to contain the two models. When called, it delegates the input to either the `generator` or the `critic`, depending on the value of `gen_mode`.
show_doc(GANModule.switch)

GANModule.switch [source]

GANModule.switch(gen_mode:(None, bool)=None)

Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise.

| | Type | Default | Details |
|---|---|---|---|
| gen_mode | (None, bool) | None | Whether the GAN should be set to generator mode |
By default (leaving `gen_mode` as `None`), this puts the module in the other mode (critic mode if it was in generator mode, and vice versa).
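As a quick illustration (not in the original notebook), the toggling behaviour can be checked on a `GANModule` built from any two modules:

m = GANModule(generator=nn.Identity(), critic=nn.Identity(), gen_mode=False)
m.switch()                # no argument: flip to the other mode
test_eq(m.gen_mode, True)
m.switch(gen_mode=False)  # explicit argument: set the mode directly
test_eq(m.gen_mode, False)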
@delegates(ConvLayer.__init__)
def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the input for the critic
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    **kwargs
) -> nn.Sequential:
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]
    while cur_size > 4:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))
        cur_ftrs *= 2 ; cur_size //= 2
    init = kwargs.get('init', nn.init.kaiming_normal_)
    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]
    return nn.Sequential(*layers)
class AddChannels(Module):
    "Add `n_dim` channels at the end of the input."
    def __init__(self, n_dim): self.n_dim=n_dim
    def forward(self, x): return x.view(*(list(x.shape)+[1]*self.n_dim))
@delegates(ConvLayer.__init__)
def basic_generator(
    out_size:int, # Output size for the generator (same as the input size of the critic)
    n_channels:int, # Number of channels of the output of the generator
    in_sz:int=100, # Size of the input noise vector for the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    **kwargs
) -> nn.Sequential:
    "A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`."
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < out_size: cur_size *= 2; cur_ftrs *= 2
    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]
    cur_size = 4
    while cur_size < out_size // 2:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))
        cur_ftrs //= 2; cur_size *= 2
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]
    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)
critic    = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst  = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
_conv_args = dict(act_cls=partial(nn.LeakyReLU, negative_slope=0.2), norm_type=NormType.Spectral)

def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):
    if self_attention: kwargs['xtra'] = SelfAttention(nf)
    return ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
@delegates(ConvLayer)
def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    **kwargs
) -> SequentialEx:
    "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`."
    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        MergeLayer(dense=True))
def gan_critic(
    n_channels:int=3, # Number of channels of the input for the critic
    nf:int=128, # Number of features for the critic
    n_blocks:int=3, # Number of ResNet blocks in the critic
    p:float=0.15 # Amount of dropout in the critic
) -> nn.Sequential:
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        DenseResBlock(nf, **_conv_args)]
    nf *= 2 # after the dense block
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    layers += [
        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
        Flatten()]
    return nn.Sequential(*layers)
class GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self,
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        gan_model:GANModule # The GAN model
    ):
        super().__init__()
        store_attr('gen_loss_func,crit_loss_func,gan_model')

    def generator(self,
        output, # Generator outputs
        target # Real images
    ):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`"
        fake_pred = self.gan_model.critic(output)
        self.gen_loss = self.gen_loss_func(fake_pred, output, target)
        return self.gen_loss

    def critic(self,
        real_pred, # Critic predictions for real images
        input # Input noise vector to pass into the generator
    ):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        fake = self.gan_model.generator(input).requires_grad_(False)
        fake_pred = self.gan_model.critic(fake)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss
show_doc(GANLoss.generator)

GANLoss.generator [source]

GANLoss.generator(output, target)

Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`

| | Type | Default | Details |
|---|---|---|---|
| output | | | Generator outputs |
| target | | | Real images |
show_doc(GANLoss.critic)

GANLoss.critic [source]

GANLoss.critic(real_pred, input)

Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`.

| | Type | Default | Details |
|---|---|---|---|
| real_pred | | | Critic predictions for real images |
| input | | | Input noise vector to pass into the generator |
If the `generator` method is called, this loss function expects the `output` of the generator and some `target` (a batch of real images). It will evaluate whether the generator successfully fooled the critic using `gen_loss_func`. This loss function has the following signature:

def gen_loss_func(fake_pred, output, target):

so that it can combine the output of the critic on `output` (the first argument, `fake_pred`) with `output` and `target` (if you want to mix the GAN loss with other losses, for instance).
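For instance, a hypothetical `gen_loss_func` following that signature could mix a standard adversarial term with an L1 pixel loss between the generated and real images (this example is not part of the library):

def example_gen_loss(fake_pred, output, target):
    adv   = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))  # fool the critic
    pixel = F.l1_loss(output, target)                                                  # stay close to the real images
    return adv + pixel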
If the `critic` method is called, this loss function expects the `real_pred` given by the critic and some `input` (the noise fed to the generator). It will evaluate the critic using `crit_loss_func`. This loss function has the following signature:

def crit_loss_func(real_pred, fake_pred):

where `real_pred` is the output of the critic on a batch of real images and `fake_pred` is generated from the noise by the generator.
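A hypothetical `crit_loss_func` with that signature could be the usual binary cross-entropy that pushes real predictions towards 1 and fake predictions towards 0 (again an illustrative example, not the library default):

def example_crit_loss(real_pred, fake_pred):
    real = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
    fake = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
    return (real + fake) / 2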
class AdaptiveLoss(Module):
    "Expand the `target` to match the `output` size before applying `crit`."
    def __init__(self, crit:callable): self.crit = crit
    def forward(self, output:Tensor, target:Tensor):
        return self.crit(output, target[:,None].expand_as(output).float())

def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True):
    "Compute thresholded accuracy after expanding `y_true` to the size of `y_pred`."
    if sigmoid: y_pred = y_pred.sigmoid()
    return ((y_pred>thresh).byte()==y_true[:,None].expand_as(y_pred).byte()).float().mean()
Callbacks for GAN training
def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool # Value to set `requires_grad` to on all parameters (`True` leaves the model trainable)
):
    for p in m.parameters(): p.requires_grad_(rg)
class GANTrainer(Callback):
    "Callback to handle GAN Training."
    run_after = TrainEvalCallback
    def __init__(self,
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        clip:None|float=None, # How much to clip the weights
        beta:float=0.98, # Exponentially weighted smoothing parameter `beta` for the losses
        gen_first:bool=False, # Whether we start with generator training
        show_img:bool=True, # Whether to show example generated images during training
    ):
        store_attr('switch_eval,clip,gen_first,show_img')
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)

    def _set_trainable(self):
        "Appropriately set the generator and critic into a trainable or loss evaluation mode based on `self.gen_mode`."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        set_freeze_model(train_model, True)
        set_freeze_model(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()
    def before_fit(self):
        "Initialization."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()
        #self.recorder.no_val = True
        #self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        #self.imgs, self.titles = [], []

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
        if self.training and self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)

    def after_batch(self):
        "Record `last_loss` in the proper list."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            self.last_gen = self.learn.to_detach(self.pred)
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    #def after_epoch(self):
    #    "Show a sample image."
    #    if not hasattr(self, 'last_gen') or not self.show_img: return
    #    data = self.learn.data
    #    img = self.last_gen[0]
    #    norm = getattr(data,'norm',False)
    #    if norm and norm.keywords.get('do_y',False): img = data.denorm(img)
    #    img = data.train_ds.y.reconstruct(img)
    #    self.imgs.append(img)
    #    self.titles.append(f'Epoch {epoch}')
    #    pbar.show_imgs(self.imgs, self.titles)
    #    return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        self.model.switch(gen_mode)
        self.loss_func.switch(gen_mode)
The `GANTrainer` is useless on its own; you need to complete it with one of the following switchers.
class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self,
        n_crit:int=1, # How many steps of critic training before switching to the generator
        n_gen:int=1 # How many steps of generator training before switching to the critic
    ):
        store_attr('n_crit,n_gen')

    def before_train(self): self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0
class AdaptiveGANSwitcher(Callback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    run_after = GANTrainer
    def __init__(self,
        gen_thresh:None|float=None, # Loss threshold for the generator
        critic_thresh:None|float=None # Loss threshold for the critic
    ):
        store_attr('gen_thresh,critic_thresh')

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.gan_trainer.gen_mode:
            if self.gen_thresh is None or self.loss < self.gen_thresh: self.gan_trainer.switch()
        else:
            if self.critic_thresh is None or self.loss < self.critic_thresh: self.gan_trainer.switch()
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    run_after = GANTrainer
    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr

    def before_batch(self):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and self.training:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        if not self.learn.gan_trainer.gen_mode: self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)
GAN data
class InvisibleTensor(TensorBase):
    "TensorBase but show method does nothing"
    def show(self, ctx=None, **kwargs): return ctx

def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size=100 # Size of returned noise vector
) -> InvisibleTensor:
    "Generate noise vector."
    return cast(torch.randn(size), InvisibleTensor)
We use the `generate_noise` function to generate noise vectors to pass into the generator for image generation.
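As a quick sanity check (not in the original notebook): the first argument is ignored, and the result is an `InvisibleTensor` of the requested size:

noise = generate_noise(None, size=100)
test_eq(type(noise), InvisibleTensor)
test_eq(noise.shape, [100])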
@typedispatch
def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs

@typedispatch
def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]
    return ctxs
bs = 128
size = 64

dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms = Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))

path = untar_data(URLs.LSUN_BEDROOMS)

dls = dblock.dataloaders(path, path=path, bs=bs)

dls.show_batch(max_n=16)
GAN Learner
def gan_loss_from_func(
    loss_gen:callable, # A loss function for the generator; evaluates generator output images against target real images
    loss_crit:callable, # A loss function for the critic; evaluates predictions for real and fake images
    weights_gen:None|MutableSequence|tuple=None # Weights for the generator and critic loss functions
):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        ones = fake_pred.new_ones(fake_pred.shape[0])
        weights_gen = ifnone(weights_gen, (1.,1.))
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        ones  = real_pred.new_ones (real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
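As an illustrative (hypothetical) use of `gan_loss_from_func`, one could build the pair of loss functions from a pixel loss for the generator and an `AdaptiveLoss`-wrapped BCE for the critic, weighting the pixel term more heavily; this mirrors what `GANLearner.from_learners` below does with the loss functions of two existing learners:

gen_loss_func, crit_loss_func = gan_loss_from_func(
    nn.L1Loss(),                           # generator loss: compares output and target images
    AdaptiveLoss(nn.BCEWithLogitsLoss()),  # critic loss: compares predictions to ones/zeros
    weights_gen=(1., 50.))                 # illustrative weights for the adversarial and pixel terms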
def _tk_mean(fake_pred, output, target): return fake_pred.mean()
def _tk_diff(real_pred, fake_pred): return real_pred.mean() - fake_pred.mean()
@delegates()
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        gen_first:bool=False, # Whether we start with generator training
        switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
        show_img:bool=True, # Whether to show example generated images during training
        clip:None|float=None, # How much to clip the weights
        cbs:Callback|None|MutableSequence=None, # Additional callbacks
        metrics:None|MutableSequence|callable=None, # Metrics
        **kwargs
    ):
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher()
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls,
        gen_learn:Learner, # A `Learner` object that contains the generator
        crit_learn:Learner, # A `Learner` object that contains the critic
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        weights_gen:None|MutableSequence|tuple=None, # Weights for the generator and critic loss functions
        **kwargs
    ):
        "Create a GAN from `learn_gen` and `learn_crit`."
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
        clip:None|float=0.01, # How much to clip the weights
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        **kwargs
    ):
        "Create a [WGAN](https://arxiv.org/abs/1701.07875) from `dls`, `generator` and `critic`."
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)

GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)
GANLearner.wgan          = delegates(to=GANLearner.__init__)(GANLearner.wgan)
show_doc(GANLearner.from_learners)

GANLearner.from_learners [source]

GANLearner.from_learners(gen_learn:Learner, crit_learn:Learner, switcher:(Callback, None)=None, weights_gen:(None, list, tuple)=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:(None, float)=None, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a GAN from `learn_gen` and `learn_crit`.

| | Type | Default | Details |
|---|---|---|---|
| gen_learn | Learner | | A `Learner` object that has the generator |
| crit_learn | Learner | | A `Learner` object that has the critic |
| switcher | (Callback, None) | None | Callback for switching between generator and critic training, defaults to `FixedGANSwitcher` |
| weights_gen | (None, list, tuple) | None | Weights for the generator and critic loss function |
| gen_first | bool | False | No Content |
| switch_eval | bool | True | No Content |
| show_img | bool | True | No Content |
| clip | (None, float) | None | No Content |
| cbs | (Callback, None, list) | None | No Content |
| metrics | (None, list, callable) | None | No Content |
| loss_func | NoneType | None | No Content |
| opt_func | function | Adam | No Content |
| lr | float | 0.001 | No Content |
| splitter | function | trainable_params | No Content |
| path | NoneType | None | No Content |
| model_dir | str | models | No Content |
| wd | NoneType | None | No Content |
| wd_bn_bias | bool | False | No Content |
| train_bn | bool | True | No Content |
| moms | tuple | (0.95, 0.85, 0.95) | No Content |
show_doc(GANLearner.wgan)

GANLearner.wgan [source]

GANLearner.wgan(dls:DataLoaders, generator:Module, critic:Module, switcher:(Callback, None)=None, clip:(None, float)=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a WGAN from `dls`, `generator` and `critic`.

| | Type | Default | Details |
|---|---|---|---|
| dls | DataLoaders | | DataLoaders object for GAN data |
| generator | Module | | Generator model |
| critic | Module | | Critic model |
| switcher | (Callback, None) | None | Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)` |
| clip | (None, float) | 0.01 | How much to clip the weights |
| switch_eval | bool | False | Whether the model should be set to eval mode when calculating loss |
| gen_first | bool | False | No Content |
| show_img | bool | True | No Content |
| cbs | (Callback, None, list) | None | No Content |
| metrics | (None, list, callable) | None | No Content |
| loss_func | NoneType | None | No Content |
| opt_func | function | Adam | No Content |
| lr | float | 0.001 | No Content |
| splitter | function | trainable_params | No Content |
| path | NoneType | None | No Content |
| model_dir | str | models | No Content |
| wd | NoneType | None | No Content |
| wd_bn_bias | bool | False | No Content |
| train_bn | bool | True | No Content |
| moms | tuple | (0.95, 0.85, 0.95) | No Content |
from fastai.callback.all import *

generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic  (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))

learn = GANLearner.wgan(dls, generator, critic, opt_func=RMSProp)

learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False

learn.fit(1, 2e-4, wd=0.)
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
epoch | train_loss | gen_loss | crit_loss | time |
---|---|---|---|---|
0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
/home/tmabraham/anaconda3/envs/fastai/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
warn("Your generator is empty.")
learn.show_results(max_n=9, ds_idx=0)
Export -
from nbdev import nbdev_export
nbdev_export()
Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.imagenette.ipynb.
Converted 23_tutorial.vision.ipynb.
Converted 24_tutorial.image_sequence.ipynb.
Converted 24_tutorial.siamese.ipynb.
Converted 24_vision.gan.ipynb.
Converted 30_text.core.ipynb.
Converted 31_text.data.ipynb.
Converted 32_text.models.awdlstm.ipynb.
Converted 33_text.models.core.ipynb.
Converted 34_callback.rnn.ipynb.
Converted 35_tutorial.wikitext.ipynb.
Converted 37_text.learner.ipynb.
Converted 38_tutorial.text.ipynb.
Converted 39_tutorial.transformers.ipynb.
Converted 40_tabular.core.ipynb.
Converted 41_tabular.data.ipynb.
Converted 42_tabular.model.ipynb.
Converted 43_tabular.learner.ipynb.
Converted 44_tutorial.tabular.ipynb.
Converted 45_collab.ipynb.
Converted 46_tutorial.collab.ipynb.
Converted 50_tutorial.datablock.ipynb.
Converted 60_medical.imaging.ipynb.
Converted 61_tutorial.medical_imaging.ipynb.
Converted 65_medical.text.ipynb.
Converted 70_callback.wandb.ipynb.
Converted 71_callback.tensorboard.ipynb.
Converted 72_callback.neptune.ipynb.
Converted 73_callback.captum.ipynb.
Converted 74_callback.azureml.ipynb.
Converted 97_test_utils.ipynb.
Converted 99_pytorch_doc.ipynb.
Converted dev-setup.ipynb.
Converted app_examples.ipynb.
Converted camvid.ipynb.
Converted migrating_catalyst.ipynb.
Converted migrating_ignite.ipynb.
Converted migrating_lightning.ipynb.
Converted migrating_pytorch.ipynb.
Converted migrating_pytorch_verbose.ipynb.
Converted ulmfit.ipynb.
Converted index.ipynb.
Converted index_original.ipynb.
Converted quick_start.ipynb.
Converted tutorial.ipynb.