• 对抗生成网络(GAN)中的损失函数


    目录

    GAN的训练过程:

    L1和L2损失函数的区别

    基础概念

    相同点

    差异


    GAN的训练过程:

    1、先定义一个标签:real = 1,fake = 0。当然这两个值的维度是按照数据的输出来看的。再定义了两个优化器。用于生成器和判别器。

    2、随机生成一个噪声z。将z作为生成器的输入,输出gen_imgs(假样本)。

    3、计算生成器的损失

    1. 定义:生成器的损失为g_loss。损失函数为adverisal_loss()。判别器为discriminator()。
    2. g_loss = adverisal_loss(discriminator(gen_imgs), real)
    3. g_loss.backward()
    4. optimizer_G.step()

    可以看出来,g_loss是根据一个输出(将生成的样本作为输入的判别器的输出)与real的一个损失。

    1)discriminator(gen_imgs) 的输出是个什么?
    既然是判别器,意思就是判别gen_imgs是不是真样本。如果是用softmax输出,是一个概率,为真样本的概率。

    2)g_loss = adverisal_loss(discriminator(gen_imgs), real)
    计算g_loss就是判别器的输出与real的差距,让g_loss越来越小,就是让gen_imgs作为判别器的输出的概率更接近valid。就是让gen_imgs更像真样本。

    3)要注意的是,这个g_loss用于去更新了生成器的权重。这个时候,判别器的权重并没有被更新。

    4、分别把假样本和真样本都送入到判别器。

    1. real_loss = adverisal_loss(discriminator(real_imgs), real)
    2. fake_loss = adverisal_loss(discriminator(gen_imgs.detach()), fake)
    3. d_loss = (real_loss + fake_loss) / 2
    4. d_loss.backward()
    5. optimizer_D.step()

    real_loss是判别器去判别真样本的输出,让这个输出更接近与real。

    fake_loss是判别器去判别假样本的输出,让这个输出更接近与fake。

    d_loss是前两者的平均。

    损失函数向后传播,就是为了让d_loss ---> 0。也就是让:

    real_loss ---> 0 ===> 让判别器的输出(真样本概率)接近 real

    fake_loss ---> 0 ===> 让判别器的输出(假样本概率)接近 fake

    也就是说,让判别器按照真假样本的类别,分别按照不同的要求去更新参数。

    5、损失函数的走向?

    g_loss 越小,说明生成器生产的假样本作为判别器的输入的输出(概率)越接近real,就是生成的假样本越像真样本。

    d_loss越小,说明判别器越能够将识别出真样本和假样本。

    所以,最后是要让g_loss更小,d_loss更接近0.5。以至于d_loss最后为0.5的时候,达到最好的效果。这个0.5的意思就是:判别器将真样本全部识别正确,所以real_loss=0。把所有的生成的假样本识别错误(生成的样本很真),此时fake_loss = 1。最后的d_loss = 1/2。

    补充:

    L1和L2损失函数的区别

    基础概念

        L1损失函数又称为MAE(mean abs error),即平均绝对误差,也就是预测值和真实值之间差值的绝对值。
        L2损失函数又称为MSE(mean square error),即平均平方误差,也就是预测值和真实值之间差值的平方。

    相同点

        因为计算的方式类似,只有一个平方的差异,因此使用的场合都很相近,通常用于回归任务中。

    差异

        1)L2没有L1鲁棒,直观来说,L2会将误差平方,如果误差大于1,则误差会被放大很多,因此模型会对异常样本更敏感,这样会牺牲许多正常的样本。当训练集中含有更多异常值的时候,L1会更有效。
        2)如果是图像重建任务,如超分辨率、深度估计、视频插帧等,L2会更加有效,这是由任务特性决定了,图像重建任务中通常预测值和真实值之间的差异不大,因此需要用L2损失来放大差异,进而指导模型的优化。
        3)L1的问题在于它的梯度在极值点会发生跃变,并且很小的差异也会带来很大的梯度,不利于学习,因此在使用时通常会设定学习率衰减策略。而L2作为损失函数的时候本身由于其函数的特性,自身就会对梯度进行缩放,因此有的任务在使用L2时甚至不会调整学习率,不过随着现在的行业认知,学习率衰减策略在很多场景中依然是获得更优模型的手段。
     

  • 相关阅读:
    POJ 2739 Sum of Consecutive Prime Numbers 尺取法
    Golang字符串和数组的相互转换
    数据结构之栈和队列以及如何封装栈和队列,栈和队列的实例(进制转换和击鼓传花)
    Leetcode 9.11每日一题 630. 课程表 III
    访问者模式
    菜鸟网络一面(超详细)
    理解 ROC 和 PRC
    交换机和路由器技术-32-命名ACL
    JIRA项目工具及日常查询
    PgSQL-执行器机制-Unique算子
  • 原文地址:https://blog.csdn.net/L888666Q/article/details/127793598