码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • mmsegmentation 添加L1Loss


    mmseg/models/losses/模块中添加L1Loss定义:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from ..builder import LOSSES
    
    
    @LOSSES.register_module()
    class L1Loss(nn.Module):
        # TODO: weight
        def __init__(self, loss_name='loss_l1', **kwargs):
            super(L1Loss, self).__init__()
            self._loss_name = loss_name
    
        def forward(self, pred, target, weight=None, ignore_index=None):
       		# pred: (n,c,h,w)   target: (n,h,w)
            classes = pred.shape[1]
            size = list(target.shape)
            size.append(classes)  # (n,h,w,c)
            target_one_hot = target.view(-1)  # (n*h*w)
            ones = torch.sparse.torch.eye(classes).to(target_one_hot.device)
            ones = ones.index_select(0, target_one_hot)  # (n*h*w, classes)
            ones = ones.view(*size)  # (n,h,w,c)
            target_one_hot = ones.permute(0, 3, 1, 2)  # (n,c,h,w)
            loss = nn.L1Loss()(pred, target_one_hot)
            return loss
    
    	@property
        def loss_name(self):
            """Loss Name.
    
            This function must be implemented and will return the name of this
            loss function. This name will be used to combine different loss items
            by simple sum operation. In addition, if you want this loss item to be
            included into the backward graph, `loss_` must be the prefix of the
            name.
    
            Returns:
                str: The name of this loss item.
            """
            return self._loss_name
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40

    注意必须要有loss_name方法,并且返回的loss_name需要以loss_作为前缀。

    传入的pred和target的shape不一致,需要转为一致才可以直接调用nn.L1Loss()方法。
    pred.shape: (n,c,h,w)
    target.shape: (n,h,w)
    所以需要将target转one-hot。转one-hot方法:index_select。
    (n,h,w) => (n,h,w,c) => (n,c,h,w)

    拓展阅读:Pytorch中,将label变成one hot编码的两种方式

  • 相关阅读:
    星环科技重磅推出数据要素流通平台Transwarp Navier,助力企业实现隐私保护下的数据安全流通与协作
    存储过程和函数
    WuThreat身份安全云-TVD每日漏洞情报-2022-12-07
    【Python零基础入门篇 · 6】:Python中的注释、字符串的常见操作、对象的布尔值
    【SQL刷题】DAY18----SQL汇总数据专项练习
    uni-app小零碎(包括封装网络请求)
    watch和watchEffect之间的小关系
    [附源码]Python计算机毕业设计超市团购平台
    训练一个图像分类器demo in PyTorch【学习笔记】
    ClickHouse的 MaterializeMySQL引擎
  • 原文地址:https://blog.csdn.net/qq_39735236/article/details/127806133
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号