• GAN里面什么时候用detach的说明


    生成对抗网络(GAN)中,生成器(G)和判别器(D)通常是两个独立的神经网络,它们之间会有梯度传播的互动。下面是一个简单的GAN的PyTorch实现,用于生成一维数据,以展示何时应该使用detach()。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    
    # 生成器
    class Generator(nn.Module):
        def __init__(self):
            super(Generator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(10, 50),
                nn.ReLU(),
                nn.Linear(50, 1)
            )
        
        def forward(self, x):
            return self.model(x)
    
    # 判别器
    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            self.model = nn.Sequential(
                nn.Linear(1, 50),
                nn.ReLU(),
                nn.Linear(50, 1),
                nn.Sigmoid()
            )
        
        def forward(self, x):
            return self.model(x)
    
    # 实例化生成器和判别器
    G = Generator()
    D = Discriminator()
    
    # 定义优化器和损失函数
    optimizer_G = optim.Adam(G.parameters(), lr=0.001)
    optimizer_D = optim.Adam(D.parameters(), lr=0.001)
    loss_func = nn.BCELoss()
    
    # 训练循环
    for epoch in range(1000):
        # 训练判别器
        D.zero_grad()
        real_data = torch.randn(100, 1)  # 真实数据
        real_labels = torch.ones(100, 1) # 真实标签
        fake_data = G(torch.randn(100, 10)).detach() # 使用detach(), 因为我们不想在这一步更新生成器
        fake_labels = torch.zeros(100, 1) # 假的标签
    
        real_loss = loss_func(D(real_data), real_labels)
    	# real_loss = loss_func(D(real_data.detach), real_labels)
        fake_loss = loss_func(D(fake_data), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()
    
        # 训练生成器
        G.zero_grad()
        noise_data = torch.randn(100, 10) # 噪声数据
        fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
        g_loss = loss_func(D(fake_data), torch.ones(100, 1))
        g_loss.backward()
        optimizer_G.step()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64

    在这个例子中:

    1. 当训练判别器(D)时,我们使用了detach()来中断梯度传播到生成器(G)。这是因为在这一步中,我们仅关心优化判别器,而不希望更新生成器的参数。
    2. 当训练生成器(G)时,我们没有使用detach(),因为我们需要通过反向传播的梯度来更新生成器的参数。

    注意:在训练判别器时,不使用real_loss = loss_func(D(real_data.detach), real_labels), 也就是这里不需要对real_data进行detach操作。

    而且即使对real_data进行.detach()操作实际上应该不会有明显影响,原因在于real_data并不是通过模型参数生成的,也不是一个需要优化的变量。.detach()方法主要用于将一个张量从当前计算图中分离出来,阻止反向传播过程中对其计算梯度。但在本例中,real_data本身就没有与需要优化的模型参数有直接关系,也不是由其他需要优化的变量通过一些运算得到的。

    注意: 在训练判别器时,使用fake_data = G(torch.randn(100, 10)).detach(), 注意是因为这个fake_data是由生成器G生成的, 为了保证分开训练判别器和生成器,即在训练判别器的时候,不对生成器的参数进行更新,这里就要把G生成的数据进行detach操作

    在训练生成器时, 也用到了判别器,用判别器去判别生成器生成的内容,希望判别器能把G生成的内容当做真的,这样就说明G的生成的内容可以以假乱真

    fake_data = G(noise_data) # 没有使用detach(), 因为我们想在这一步更新生成器
    g_loss = loss_func(D(fake_data), torch.ones(100, 1))
    g_loss.backward()
    optimizer_G.step()
    
    • 1
    • 2
    • 3
    • 4

    上面没有对传进D的fake_data进行detach,是因为下面的代码只有g_loss_backward(),也就是只对G进行参数更新,当然这里也不能对fake_data进行detach,如果detach了,就无法更新G的参数了

  • 相关阅读:
    华为机试真题 C++ 实现【连续字母长度】
    Spark---持久化,共享变量和RDD之间的依赖关系详解
    open-set recognition(OSR)开集识别
    C语言-联合体与枚举类型
    [网鼎杯 2020 朱雀组]Nmap wp
    java毕业设计软件S2SH人力资源管理系统|人事薪资招聘oa人力请假考勤工资[包运行成功]
    软路由和硬路由的区别是什么,性价比与可玩性分析
    SpringMVC之国际化&上传&下载
    记一次Nacos线程数飙升排查
    草稿草稿草稿,python 和VBA的差别对比汇总 收集ing
  • 原文地址:https://blog.csdn.net/weixin_43845922/article/details/133163478