这里是对pytorch的代码进行的解析:GitHub - facebookresearch/pytorch_GAN_zoo: A mix of GAN implementations including progressive growing
解析的过程我采用:自顶向下、逐步求精的方法,也就是从全局到局部
整体的一个运行关系:
train.py 是整个项目的起点:加载什么样的配置(parser.add_argument + config_256_ChEMBL.json)、加载数据集、开始train...
通过train.py中的“getTrainer”函数加载 progressive_gan_trainer.py 和progressive_gan_trainer.py中的 ProgressiveGANTrainer class
G 和 D的构造位置在:models --> network --> progressive_conv_net.py 中:
作者采用progressive growing的训练方式,先训一个小分辨率的图像生成,训好了之后再逐步过渡到更高分辨率的图像。然后稳定训练当前分辨率,再逐步过渡到下一个更高的分辨率。
如上图所示。更具体点来说,当处于fade in(或者说progressive growing)阶段的时候,上一分辨率(4*4)会通过resize+conv操作得到跟下一分辨率(8*8)同样大小的输出,然后两部分做加权,再通过to_rgb操作得到最终的输出。这样做的一个好处是它可以充分利用上个分辨率训练的结果,通过缓慢的过渡(w逐渐增大),使得训练生成下一分辨率的网络更加稳定。
下图是Discriminator的growing,它跟Generator的类似,差别在于一个是上采样,一个是下采样。这里就不再赘述。