• InfoGAN原理PyTorch实现Debug记录


    CGAN从无监督GAN改进成有监督的GAN

    GAN的基本原理输入是随机噪声,无法控制输出和输入之间的对应关系,也无法控制输出的模式,CGAN全称是条件GAN(Conditional GAN)改进基本的GAN解决了这个问题,CGAN和基本的GAN不同的地方是:
    参考下面的链接
    https://www.jianshu.com/p/39c57e9a6630
    这里面介绍了实现CGAN有三种形式,从网络实现上的三种形式,没有讲解怎样优化目标函数

    CGAN的一个问题是输入的有监督标签是离散型输入,如果输入中还有连续型输入,也就是C这个条件是个连续型的,那么将要继续参考InfoGAN

    InfoGAN

    参考下面的链接,非常详细的讲解了InfoGAN的原理、网络结构的实现、损失函数怎样求解
    https://www.jianshu.com/p/fa892c81df60
    InfoGAN的Info部分和判别器D共用了前面的网络,那么PyTorch怎么实现共用网咯呢?
    参考下面的PyTorch实现
    https://mp.weixin.qq.com/s?__biz=MzI3MzkyMzE5Mw==&mid=2247485031&idx=1&sn=e6ccbc33639462d59ee56923c59173b6&chksm=eb1aab71dc6d2267cc52bf769106067c53c867ad6a02063791674937857fb86da36ecd6cbfd9&token=1864035800&lang=zh_CN#rd
    在这里插入图片描述
    原来PyTorch定义判别器类的时候可以分成三个网络,分别是主网络、D网络和C网络、L网络,D网络和C网络和L网络公用主网络,这个例子中的InfoGAN得输入有随机噪声、离散输入(C部分)、连续输入(L部分),forward中先用主网络处理x,之后返回D网络、C网络和L网络
    不得不说,这种写法很有趣啊

    PyTorch实现Debug记录

    我自己实现了InfoGAN网络,运行程序后接二连三出现了很多错误

    Bug(1):

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace
    参考链接:
    https://blog.csdn.net/qq_32953463/article/details/115728762
    出现这个错误的原因是Pytorch的版本问题,我的Pytorch是1.11.0版本,如果Pytorch版本低于1.4不会有这个问题,链接中提供了一种不需要重新安装Pytorch的办法,backward()放在一起,step()放在一起,zero_grad()不需要放在一起,如下截图所示,

    不得不说,这么神奇,真的解决了

                real_out = netD(real_img).mean()
                fake_out = netD(fake_img).mean()
                d_loss = 1 - real_out + fake_out
                netD.zero_grad()
                g_loss = generator_criterion(fake_out, fake_img, real_img)
                netG.zero_grad()
    
                d_loss.backward(retain_graph=True)
                g_loss.backward()
    
                optimizerD.step()
                optimizerG.step()
    
                fake_img = netG(z)
                fake_out = netD(fake_img).mean()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    Bug(2):RuntimeError: Trying to backward through the graph a second time but the buffers have already been f

    或者说是pytorch中的retain_graph=True的作用

    参考链接:https://blog.csdn.net/qq_39861441/article/details/104129368
    总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。

    在这里插入图片描述

    上面的示例代码中前两个网络D和G在backward的时候使用了retain_graph=True的参数,最后一个网络没有使用此参数,此参数的默认值是False

    如果想了解底层的原理,建议阅读下面的链接,里面的图解非常的有趣
    https://blog.csdn.net/SY_qqq/article/details/107384161

    Bug(3):

    RuntimeError: Found dtype Long but expected Float
    这个错误来源于torch需要float类型,但是数据中是int类型或者long类型,解决方法是debug一个一个看变量中哪里出现了int或者long类型,假设variable是int或者long类型的变量,将它转换成float类型

    variable = variable.to(torch.float32)
    
    • 1
  • 相关阅读:
    无人机/FPV穿越机的遥控器/接收机等配件厂商
    1999-2021地级市GDP及一二三产业GDP数据
    2、使用RedisTemplate实现基本数据类型增删改查
    数的连接|NOIP1998 T2|贪心算法
    HW-小记(二)
    FreeRTOS教程1 基础知识
    SpringMVC之JSON数据返回与异常处理机制---全方面讲解
    字节跳动二面(消息队列+分布式+CAS+ThreadLocal)
    旅行季《乡村振兴战略下传统村落文化旅游设计》许少辉八一新著作想象和世界一样宽广
    数据结构与算法——排序算法
  • 原文地址:https://blog.csdn.net/ningmengzhihe/article/details/125500558