• 【GAN对抗性损失函数】以CycleGAN和PIX2PIX算法的对抗性损失的代码为例进行讲解


    一、代码

    class GANLoss(nn.Module):
        """Define different GAN objectives.
        The GANLoss class abstracts away the need to create the target label tensor
        that has the same size as the input.
        """
        def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
            """ Initialize the GANLoss class.
            Parameters:
                gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
                target_real_label (bool) - - label for a real image
                target_fake_label (bool) - - label of a fake image
            Note: Do not use sigmoid as the last layer of Discriminator.
            LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
            """
            super(GANLoss, self).__init__()
            self.register_buffer('real_label', torch.tensor(target_real_label))
            self.register_buffer('fake_label', torch.tensor(target_fake_label))
            self.gan_mode = gan_mode
            if gan_mode == 'lsgan':
                self.loss = nn.MSELoss()
            elif gan_mode == 'RidgeRegressionaLoss':
                self.loss = RidgeLoss1(alpha=0.1)
            elif gan_mode == 'vanilla':
                self.loss = nn.BCEWithLogitsLoss()
            elif gan_mode in ['wgangp']:
                self.loss = None
            else:
                raise NotImplementedError('gan mode %s not implemented' % gan_mode)
        def get_target_tensor(self, prediction, target_is_real):
            """Create label tensors with the same size as the input.
            Parameters:
                prediction (tensor) - - tpyically the prediction from a discriminator
                target_is_real (bool) - - if the ground truth label is for real images or fake images
            Returns:
                A label tensor filled with ground truth label, and with the size of the input
            """
            if target_is_real:
                target_tensor = self.real_label
            else:
                target_tensor = self.fake_label
            return target_tensor.expand_as(prediction)
    
        def __call__(self, prediction, target_is_real):
            """Calculate loss given Discriminator's output and grount truth labels.
    
            Parameters:
                prediction (tensor) - - tpyically the prediction output from a discriminator
                target_is_real (bool) - - if the ground truth label is for real images or fake images
    
            Returns:
                the calculated loss.
            """
            if self.gan_mode in ['lsgan', 'vanilla','RidgeRegressionaLoss']:
                target_tensor = self.get_target_tensor(prediction, target_is_real)
                loss = self.loss(prediction, target_tensor)
            elif self.gan_mode == 'wgangp':
                if target_is_real:
                    loss = -prediction.mean()
                else:
                    loss = prediction.mean()
            return loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61

    二、讲解

    target_tensor.expand_as(prediction)的意思是将target_tensor张量的尺寸扩展为与prediction张量相同的尺寸。

    生成对抗网络(GAN)中,判别器的输出通常是一个张量,表示样本为真实样本的概率或得分。为了计算损失,需要创建与判别器输出相同尺寸的目标标签张量。target_tensorget_target_tensor方法中获得,表示目标标签,可以是真实样本标签或虚假样本标签。为了与判别器的输出张量进行元素级别的比较,需要将目标标签张量的尺寸扩展为与判别器输出相同的形状。

    expand_as(prediction)方法是一个张量的方法,它返回一个尺寸与prediction张量相同的新张量,其中新张量的元素以target_tensor的元素进行填充或重复,以便与prediction进行逐元素比较。

    通过将目标标签张量的尺寸扩展为与判别器输出相同的尺寸,可以确保在计算损失时每个生成样本或真实样本的标签都与对应的判别器输出进行比较。

  • 相关阅读:
    CSDN竞赛第四期季军 解题思路及参赛经历分享
    解决Redis缓存穿透(缓存空对象、布隆过滤器)
    【LeetCode每日一题合集】2023.8.28-2023.9.3(到家的最少跳跃次数)
    知道策略模式!但不会在项目里使用?
    llvm源码windows编译
    在项目中单元测试是用来做什么的?
    window10彻底关闭系统管理员控制(所有软件以管理员身份运行)
    网络编程套接字
    ros develop 相关
    关于容器镜像那些事
  • 原文地址:https://blog.csdn.net/lingchen1906/article/details/133469463