! [ -e /content ] && pip install -Uqq fastai # 在Colab上升级fastai
from __future__ import annotations
from fastai.basics import *
from fastai.vision.all import *
### 默认类级别 3
from nbdev.showdoc import *
GAN 代表 生成对抗网络,由 Ian Goodfellow 发明。其概念是我们同时训练两个模型:一个生成器和一个鉴别器。生成器会尝试生成与数据集中相似的新图像,而鉴别器则会尝试区分真实图像和生成器生成的图像。生成器输出图像,鉴别器输出一个数字(通常是一个概率,真实图像为 1,假图像为 0)。
- 冻结生成器,训练鉴别器一步:
- 获取一批真实图像(我们称之为 `real`)
- 生成一批假图像(我们称之为 `fake`)
- 让鉴别器评估每一批,并计算损失函数;重要的是,它会对检测到真实图像给予正向奖励,而对假图像给予惩罚
- 用这个损失的梯度更新鉴别器的权重
- 冻结鉴别器,训练生成器一步:
- 生成一批假图像
- 在其上评估鉴别器
- 返回一个损失,正向奖励鉴别器认为这些是现实图像的情况
- 用这个损失的梯度更新生成器的权重
class GANModule(Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self,
        generator:nn.Module=None, # The generator PyTorch module
        critic:nn.Module=None, # The critic (discriminator) PyTorch module
        gen_mode:None|bool=False # Whether the GAN should be set to generator mode
    ):
        # Sub-modules are only attached when provided, so a subclass may define
        # `generator`/`critic` itself (e.g. as methods) without being overwritten.
        if generator is not None: self.generator = generator
        if critic is not None: self.critic = critic
        # `forward` and `switch` both read `self.gen_mode`, so it must be stored here.
        self.gen_mode = gen_mode

    def forward(self, *args):
        "Dispatch `args` to the generator in generator mode, to the critic otherwise."
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self,
        gen_mode:None|bool=None # Whether the GAN should be set to generator mode
    ):
        "Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise."
        # `None` toggles the current mode instead of setting it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
默认情况下(将 `gen_mode` 留空为 `None`),调用 `switch` 会把模块切换到另一种模式(若当前处于生成器模式则切换为评论者模式,反之亦然)。
def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the input for the critic
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    **kwargs # Extra keyword arguments forwarded to every `ConvLayer`
) -> nn.Sequential:
    "A basic critic for images `n_channels` x `in_size` x `in_size`."
    # First stride-2 convolution halves the spatial size and uses no normalization.
    layers = [ConvLayer(n_channels, n_features, 4, 2, 1, norm_type=None, **kwargs)]
    cur_size, cur_ftrs = in_size//2, n_features
    # Optional extra stride-1 layers keep the feature count unchanged.
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, norm_type=norm_type, **kwargs) for _ in range(n_extra_layers)]
    # Keep halving the size (and doubling the features) until the tracked size is at most 4.
    while cur_size > 4:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs*2, 4, 2, 1, norm_type=norm_type, **kwargs))
        cur_ftrs *= 2
        cur_size //= 2
    init = kwargs.get('init', nn.init.kaiming_normal_)
    # Final unpadded 4x4 convolution maps to a single channel, then the result is flattened.
    layers += [init_default(nn.Conv2d(cur_ftrs, 1, 4, padding=0), init), Flatten()]
    return nn.Sequential(*layers)
class AddChannels(Module):
    "Append `n_dim` trailing dimensions of size 1 to the input."
    def __init__(self, n_dim):
        self.n_dim = n_dim

    def forward(self, x):
        # Same data, viewed with `n_dim` extra singleton axes at the end.
        trailing_ones = [1] * self.n_dim
        return x.view(*x.shape, *trailing_ones)
def basic_generator(
    out_size:int, # Output size for the generator (same as the input size for the critic)
    n_channels:int, # Number of channels of the output of the generator
    in_sz:int=100, # Size of the input noise vector for the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    **kwargs # Extra keyword arguments forwarded to every `ConvLayer`
) -> nn.Sequential:
    "A basic generator from `in_sz` to images `n_channels` x `out_size` x `out_size`."
    # Work out the feature count of the first layer: starting from `n_features//2`,
    # double it once for every doubling needed to go from size 4 to `out_size`.
    cur_size, cur_ftrs = 4, n_features//2
    while cur_size < out_size:
        cur_size *= 2
        cur_ftrs *= 2
    # `AddChannels(2)` appends two singleton axes to the noise before the transposed convs.
    layers = [AddChannels(2), ConvLayer(in_sz, cur_ftrs, 4, 1, transpose=True, **kwargs)]
    cur_size = 4
    # Each stride-2 transposed conv doubles the tracked size and halves the features.
    while cur_size < out_size // 2:
        layers.append(ConvLayer(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True, **kwargs))
        cur_ftrs //= 2
        cur_size *= 2
    layers += [ConvLayer(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True, **kwargs) for _ in range(n_extra_layers)]
    # Last stride-2 upsampling to `n_channels`, followed by `Tanh`.
    layers += [nn.ConvTranspose2d(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]
    return nn.Sequential(*layers)
# Smoke test: a GANModule built from the basic critic/generator dispatches by mode.
critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
# Default keyword arguments shared by the critic's conv layers (unpacked by `_conv`).
_conv_args = dict(act_cls = partial(nn.LeakyReLU, negative_slope=0.2), norm_type=NormType.Spectral)
def _conv(ni, nf, ks=3, stride=1, self_attention=False, **kwargs):
    "Build a `ConvLayer` from `ni` to `nf` using `_conv_args`, optionally with a `SelfAttention` extra."
    if self_attention:
        # Overrides any caller-supplied `xtra`.
        kwargs.update(xtra=SelfAttention(nf))
    layer = ConvLayer(ni, nf, ks=ks, stride=stride, **_conv_args, **kwargs)
    return layer
def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    **kwargs # Optional parameters forwarded to both `ConvLayer`s
) -> SequentialEx:
    "Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`."
    # Two same-width conv layers, then a dense merge with the block input.
    return SequentialEx(ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        ConvLayer(nf, nf, norm_type=norm_type, **kwargs),
                        MergeLayer(dense=True))
def gan_critic(
    n_channels:int=3, # Number of channels of the input for the critic
    nf:int=128, # Number of features for the critic
    n_blocks:int=3, # Number of ResNet blocks within the critic
    p:float=0.15 # Amount of dropout in the critic
) -> nn.Sequential:
    "Critic to train a `GAN`."
    layers = [
        _conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        DenseResBlock(nf, **_conv_args)]
    nf *= 2 # after dense block
    for i in range(n_blocks):
        # Self-attention is only requested on the first of these blocks.
        layers += [
            nn.Dropout2d(p),
            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    # Final unpadded conv to one channel with no activation, then flatten.
    layers += [
        ConvLayer(nf, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
        Flatten()]
    return nn.Sequential(*layers)
class GANLoss(GANModule):
    "Wrapper around `crit_loss_func` and `gen_loss_func`"
    def __init__(self,
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        gan_model:GANModule # The GAN model
    ):
        # The parent initializer sets `gen_mode`, which `forward` uses to pick
        # between the `generator` and `critic` methods defined below.
        super().__init__()
        self.gen_loss_func = gen_loss_func
        self.crit_loss_func = crit_loss_func
        self.gan_model = gan_model

    def generator(self,
        output, # Generator outputs
        target # Real images
    ):
        "Evaluate the `output` with the critic then uses `self.gen_loss_func` to evaluate how well the critic was fooled by `output`"
        fake_pred = self.gan_model.critic(output)
        # The value is also kept on the instance as `gen_loss`.
        self.gen_loss = self.gen_loss_func(fake_pred, output, target)
        return self.gen_loss

    def critic(self,
        real_pred, # Critic predictions for real images
        input # Input noise vector to pass into generator
    ):
        "Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`."
        # The generated batch is marked as not requiring grad before the critic sees it.
        fake = self.gan_model.generator(input).requires_grad_(False)
        fake_pred = self.gan_model.critic(fake)
        # The value is also kept on the instance as `crit_loss`.
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss
def gen_loss_func(fake_pred, output, target):
def crit_loss_func(real_pred, fake_pred):
class AdaptiveLoss(Module):
    "Broadcast `target` to the shape of `output`, then apply `crit`."
    def __init__(self, crit:callable):
        self.crit = crit

    def forward(self, output:Tensor, target:Tensor):
        # Insert an axis after the first one, expand to `output`'s shape, cast to float.
        expanded_target = target[:,None].expand_as(output).float()
        return self.crit(output, expanded_target)
def accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True):
    "Thresholded accuracy of `y_pred` against `y_true` expanded to the shape of `y_pred`."
    scores = y_pred.sigmoid() if sigmoid else y_pred
    # Binarize the predictions and broadcast the targets to the same shape.
    predicted = (scores > thresh).byte()
    expected = y_true[:,None].expand_as(scores).byte()
    matches = (predicted == expected).float()
    return matches.mean()
def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool # Value passed to `requires_grad_`: `False` freezes the parameters, `True` makes them trainable
):
    "Set `requires_grad` to `rg` on every parameter of `m`."
    for p in m.parameters(): p.requires_grad_(rg)
class GANTrainer(Callback):
    "Callback to handle GAN Training."
    run_after = TrainEvalCallback
    def __init__(self,
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        clip:None|float=None, # How much to clip the weights
        beta:float=0.98, # Exponentially weighted smoothing of the losses `beta`
        gen_first:bool=False, # Whether we start with generator training
        show_img:bool=True, # Whether to show example generated images during training
    ):
        store_attr('switch_eval,clip,gen_first,show_img')
        # One smoothed-loss tracker per model.
        self.gen_loss,self.crit_loss = AvgSmoothLoss(beta=beta),AvgSmoothLoss(beta=beta)

    def _set_trainable(self):
        "Appropriately set the generator and critic into a trainable or loss evaluation mode based on `self.gen_mode`."
        train_model = self.generator if self.gen_mode else self.critic
        loss_model = self.critic if self.gen_mode else self.generator
        # `requires_grad` is turned on for the model being trained and off for the other.
        set_freeze_model(train_model, True)
        set_freeze_model(loss_model, False)
        if self.switch_eval:
            train_model.train()
            loss_model.eval()

    def before_fit(self):
        "Grab the generator/critic from the model, select the starting mode and reset the loss trackers."
        self.generator,self.critic = self.model.generator,self.model.critic
        self.gen_mode = self.gen_first
        self.switch(self.gen_mode)
        self.crit_losses,self.gen_losses = [],[]
        self.gen_loss.reset() ; self.crit_loss.reset()
        #self.recorder.no_val = True
        #self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        #self.imgs, self.titles = [], []

    def before_validate(self):
        "Switch in generator mode for showing results."
        self.switch(gen_mode=True)

    def before_batch(self):
        "Clamp the weights with `self.clip` if it's not None, set the correct input/target."
        if self.training and self.clip is not None:
            for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
        if not self.gen_mode:
            # In critic mode the batch's inputs and targets are swapped.
            (self.learn.xb,self.learn.yb) = (self.yb,self.xb)

    def after_batch(self):
        "Record `last_loss` in the proper list."
        if not self.training: return
        if self.gen_mode:
            self.gen_loss.accumulate(self.learn)
            self.gen_losses.append(self.gen_loss.value)
            # Keep the latest generator predictions, detached.
            self.last_gen = self.learn.to_detach(self.pred)
        else:
            self.crit_loss.accumulate(self.learn)
            self.crit_losses.append(self.crit_loss.value)

    def before_epoch(self):
        "Put the critic or the generator back to eval if necessary."
        self.switch(self.gen_mode)

    #def after_epoch(self):
    #    "Show a sample image."
    #    if not hasattr(self, 'last_gen') or not self.show_img: return
    #    data = self.learn.data
    #    img = self.last_gen[0]
    #    norm = getattr(data,'norm',False)
    #    if norm and norm.keywords.get('do_y',False): img = data.denorm(img)
    #    img = data.train_ds.y.reconstruct(img)
    #    self.imgs.append(img)
    #    self.titles.append(f'Epoch {epoch}')
    #    pbar.show_imgs(self.imgs, self.titles)
    #    return add_metrics(last_metrics, [getattr(self.smoothenerG,'smooth',None),getattr(self.smoothenerC,'smooth',None)])

    def switch(self, gen_mode=None):
        "Switch the model and loss function, if `gen_mode` is provided, in the desired mode."
        # `None` toggles the current mode.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self._set_trainable()
        # Pass the resolved mode so the model and loss wrappers cannot drift from the trainer.
        self.model.switch(self.gen_mode)
        self.loss_func.switch(self.gen_mode)
class FixedGANSwitcher(Callback):
    "Switcher to do `n_crit` iterations of the critic then `n_gen` iterations of the generator."
    run_after = GANTrainer
    def __init__(self,
        n_crit:int=1, # How many steps of critic training before switching to generator
        n_gen:int=1 # How many steps of generator training before switching to critic
    ):
        store_attr('n_crit,n_gen')

    def before_train(self):
        "Reset the critic/generator step counters."
        self.n_c,self.n_g = 0,0

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            n_iter,n_in,n_out = self.n_gen,self.n_c,self.n_g
        else:
            self.n_c += 1
            n_iter,n_in,n_out = self.n_crit,self.n_g,self.n_c
        # A non-int `n_iter` is called with the other model's step count to get the target.
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        if target == n_out:
            self.learn.gan_trainer.switch()
            self.n_c,self.n_g = 0,0
class AdaptiveGANSwitcher(Callback):
    "Switcher that goes back to generator/critic when the loss goes below `gen_thresh`/`crit_thresh`."
    run_after = GANTrainer
    def __init__(self,
        gen_thresh:None|float=None, # Generator loss threshold
        critic_thresh:None|float=None # Critic loss threshold
    ):
        store_attr('gen_thresh,critic_thresh')

    def after_batch(self):
        "Switch the model if necessary."
        if not self.training: return
        # Only the threshold of the mode currently being trained is consulted;
        # a `None` threshold means "switch after every batch".
        if self.gan_trainer.gen_mode:
            if self.gen_thresh is None or self.loss < self.gen_thresh: self.gan_trainer.switch()
        else:
            if self.critic_thresh is None or self.loss < self.critic_thresh: self.gan_trainer.switch()
class GANDiscriminativeLR(Callback):
    "`Callback` that handles multiplying the learning rate by `mult_lr` for the critic."
    run_after = GANTrainer
    def __init__(self, mult_lr=5.): self.mult_lr = mult_lr

    def before_batch(self):
        "Multiply the current lr if necessary."
        if not self.learn.gan_trainer.gen_mode and self.training:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']*self.mult_lr)

    def after_batch(self):
        "Put the LR back to its value if necessary."
        # Same guard as `before_batch`, so the LR is only divided back when it was
        # actually multiplied (i.e. never on a non-training batch).
        if not self.learn.gan_trainer.gen_mode and self.training:
            self.learn.opt.set_hyper('lr', self.learn.opt.hypers[0]['lr']/self.mult_lr)
GAN 数据
class InvisibleTensor(TensorBase):
    "A `TensorBase` whose `show` draws nothing."
    def show(self, ctx=None, **kwargs):
        # Nothing is rendered; the context is handed back untouched.
        return ctx
def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size=100 # Size of returned noise vector
) -> InvisibleTensor:
    "Generate noise vector."
    # `fn` is ignored; every call draws a fresh standard-normal sample.
    return cast(torch.randn(size), InvisibleTensor)
我们使用 `generate_noise` 函数生成噪声向量,作为生成器的输入。
# Registered as a type-dispatched overload: `show_batch[object]` below looks up the
# generic implementation on the dispatcher, which a plain function would not support.
@typedispatch
def show_batch(x:InvisibleTensor, y:TensorImage, samples, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show a batch of noise inputs / image targets on a grid with at most `max_n` cells."
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs
# Registered as a type-dispatched overload so it extends, rather than shadows,
# the `show_results` dispatcher brought in by the fastai star imports.
@typedispatch
def show_results(x:InvisibleTensor, y:TensorImage, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):
    "Show the first item of each of `outs` on a grid with at most `max_n` cells."
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
    # Zipping with `range(max_n)` caps the number of outputs drawn.
    ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(0),ctxs,range(max_n))]
    return ctxs
bs = 128
size = 64
# Inputs are noise vectors from `generate_noise`; targets are the image files.
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),  # empty index list: nothing is held out for validation
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
GAN 学习器
def gan_loss_from_func(
    loss_gen:callable, # A loss function for the generator. Evaluates generator output images and target real images
    loss_crit:callable, # A loss function for the critic. Evaluates predictions of real and fake images.
    weights_gen:None|MutableSequence|tuple=None # Weights for the generator and critic loss function
):
    "Define loss functions for a GAN from `loss_gen` and `loss_crit`."
    def _loss_G(fake_pred, output, target, weights_gen=weights_gen):
        # The generator is scored against an all-ones target on the critic's predictions.
        ones = fake_pred.new_ones(fake_pred.shape[0])
        if weights_gen is None: weights_gen = (1.,1.)
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        # Real predictions are scored against ones, fake predictions against zeros; the two are averaged.
        ones = real_pred.new_ones(real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
def _tk_mean(fake_pred, output, target):
    "Return the mean of `fake_pred`; `output` and `target` are accepted but unused."
    mean_fake = fake_pred.mean()
    return mean_fake
def _tk_diff(real_pred, fake_pred):
    "Return the mean of `real_pred` minus the mean of `fake_pred`."
    mean_real = real_pred.mean()
    mean_fake = fake_pred.mean()
    return mean_real - mean_fake
class GANLearner(Learner):
    "A `Learner` suitable for GANs."
    def __init__(self,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        gen_loss_func:callable, # Generator loss function
        crit_loss_func:callable, # Critic loss function
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        gen_first:bool=False, # Whether we start with generator training
        switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
        show_img:bool=True, # Whether to show example generated images during training
        clip:None|float=None, # How much to clip the weights
        cbs:Callback|None|MutableSequence=None, # Additional callbacks
        metrics:None|MutableSequence|callable=None, # Metrics
        **kwargs # Remaining keyword arguments forwarded to `Learner.__init__`
    ):
        # Wrap both networks in one module and both losses in one loss wrapper.
        gan = GANModule(generator, critic)
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan)
        if switcher is None: switcher = FixedGANSwitcher()
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        # The trainer and switcher callbacks are appended after any user callbacks.
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls,
        gen_learn:Learner, # A `Learner` object that contains the generator
        crit_learn:Learner, # A `Learner` object that contains the critic
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
        weights_gen:None|MutableSequence|tuple=None, # Weights for the generator and critic loss function
        **kwargs # Remaining keyword arguments forwarded to `GANLearner.__init__`
    ):
        "Create a GAN from `learn_gen` and `learn_crit`."
        # Both GAN losses are derived from the two learners' own loss functions;
        # the data comes from the generator learner.
        losses = gan_loss_from_func(gen_learn.loss_func, crit_learn.loss_func, weights_gen=weights_gen)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

    @classmethod
    def wgan(cls,
        dls:DataLoaders, # DataLoaders object for GAN data
        generator:nn.Module, # Generator model
        critic:nn.Module, # Critic model
        switcher:Callback|None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
        clip:None|float=0.01, # How much to clip the weights
        switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
        **kwargs # Remaining keyword arguments forwarded to `GANLearner.__init__`
    ):
        "Create a [WGAN](https://arxiv.org/abs/1701.07875) from `dls`, `generator` and `critic`."
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        # `_tk_mean` is passed as the generator loss and `_tk_diff` as the critic loss.
        return cls(dls, generator, critic, _tk_mean, _tk_diff, switcher=switcher, clip=clip, switch_eval=switch_eval, **kwargs)
# Re-wrap both alternate constructors with `delegates` targeting `GANLearner.__init__`.
GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)
GANLearner.wgan = delegates(to=GANLearner.__init__)(GANLearner.wgan)
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic(64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
# The splitter above leaves the validation set empty, so report metrics on the training set only.
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
learn.show_results(max_n=9, ds_idx=0)
