• 基于深度学习的图像去雨去雾


    基于深度学习的图像去雨去雾


    文末附有源码下载地址
    b站视频地址: https://www.bilibili.com/video/BV1Jr421p7cT/

    基于深度学习的图像去雨去雾，使用的网络为 UNet 结构（以 ResNet18 作为编码器）。
    网络代码：

    import torch
    import torch.nn as nn
    from torchsummary import summary
    from torchvision import models
    from torchvision.models.feature_extraction import create_feature_extractor
    import torch.nn.functional as F
    from torchstat import stat
    
    class Resnet18(nn.Module):
        """ResNet-18 trunk that exposes four intermediate feature maps.

        Only the stem and ``layer1``..``layer3`` are executed; the remaining
        children of the torchvision model (``layer4``, ``avgpool``, ``fc``) are
        constructed but never run by ``forward``.
        """

        # Names of the torchvision ResNet children whose outputs are returned,
        # in the order the network produces them.  The last one also marks
        # where execution stops.
        _TAPS = ('relu', 'layer1', 'layer2', 'layer3')

        def __init__(self):
            super().__init__()
            # Randomly initialised backbone (no pretrained weights).
            # NOTE(review): `pretrained=` is deprecated in newer torchvision in
            # favour of `weights=None`; kept for compatibility with older releases.
            self.resnet = models.resnet18(pretrained=False)

        def forward(self, x):
            """Run the trunk up to ``layer3``.

            Returns:
                A 4-tuple ``(stem_relu, layer1, layer2, layer3)`` of activations,
                where ``stem_relu`` is taken right after the stem ReLU (before
                the max-pool).
            """
            collected = []
            last_tap = self._TAPS[-1]
            for child_name, child in self.resnet.named_children():
                x = child(x)
                if child_name not in self._TAPS:
                    continue
                collected.append(x)
                if child_name == last_tap:
                    break
            return tuple(collected)
    class Linears(nn.Module):
        """Bottleneck MLP ``a -> b -> a`` with a LeakyReLU and a sigmoid output.

        Args:
            a: input and output feature size.
            b: hidden feature size.
        """

        def __init__(self, a, b):
            super().__init__()
            # Submodules are registered in this exact order so state_dict keys
            # and seeded parameter initialisation stay stable.
            self.linear1 = nn.Linear(a, b)
            self.relu1 = nn.LeakyReLU()
            self.linear2 = nn.Linear(b, a)
            self.sigmoid = nn.Sigmoid()

        def forward(self, x):
            """Map ``(..., a)`` features to ``(..., a)`` sigmoid-squashed values."""
            hidden = self.relu1(self.linear1(x))
            return self.sigmoid(self.linear2(hidden))
    class DenseNetBlock(nn.Module):
        """Residual/dense-style block of three 3x3 conv-BN-LeakyReLU units.

        ``forward`` computes::

            y1  = act(bn2(conv2(act(bn1(conv1(x)))))) + s
            y2  = act(bn3(conv3(y1)))
            out = s + y1 + y2

        where ``s`` is ``x`` itself when ``inplanes == planes`` and
        ``stride == 1``; otherwise it is a 1x1 conv + BN projection of ``x`` so
        the residual sums have matching shapes.

        Args:
            inplanes: channels of the input tensor.
            planes: channels of the output tensor.
            stride: stride of the first convolution (spatial downsampling).
        """

        def __init__(self, inplanes=1, planes=1, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 1)
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu1 = nn.LeakyReLU()

            # conv2/conv3 consume conv1's output, which already has `planes`
            # channels at the strided resolution, so they take `planes` inputs
            # with stride 1.
            self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1)
            self.bn2 = nn.BatchNorm2d(planes)
            self.relu2 = nn.LeakyReLU()

            self.conv3 = nn.Conv2d(planes, planes, 3, 1, 1)
            self.bn3 = nn.BatchNorm2d(planes)
            self.relu3 = nn.LeakyReLU()

            # Projection for the residual path, created last and only when the
            # shapes differ.  In the identity configuration the parameters,
            # state_dict keys and seeded initialisation are therefore unchanged.
            if inplanes != planes or stride != 1:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(inplanes, planes, 1, stride, bias=False),
                    nn.BatchNorm2d(planes),
                )
            else:
                self.shortcut = None

        def forward(self, x):
            """Apply the block to an ``(N, inplanes, H, W)`` tensor."""
            identity = x if self.shortcut is None else self.shortcut(x)

            y1 = self.relu1(self.bn1(self.conv1(x)))
            y1 = self.relu2(self.bn2(self.conv2(y1)))
            y1 = y1 + identity

            y2 = self.relu3(self.bn3(self.conv3(y1)))

            return identity + y1 + y2
    class SEnet(nn.Module):
        """Squeeze-and-Excitation channel attention.

        Global-average-pools each channel, passes the pooled vector through a
        bottleneck MLP (``chs -> chs // reduction -> chs``) followed by a
        sigmoid, and rescales the input channels by the resulting gates.

        Args:
            chs: number of input/output channels.
            reduction: bottleneck reduction ratio.  The hidden width is clamped
                to at least 1 so that ``chs < reduction`` still builds a
                usable MLP.
        """

        def __init__(self, chs, reduction=4):
            super().__init__()
            self.average_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
            hidden = max(1, chs // reduction)
            self.fc = nn.Sequential(
                # First reduce dimension, then raise dimension.
                # Add nonlinear processing to fit the correlation between channels
                nn.Linear(chs, hidden),
                nn.LeakyReLU(inplace=True),
                nn.Linear(hidden, chs),
            )
            self.activation = nn.Sigmoid()

        def forward(self, x):
            """Rescale the channels of an ``(N, chs, H, W)`` tensor by learned gates."""
            batch_size, chs, _, _ = x.shape
            squeezed = self.average_pooling(x).view(batch_size, chs)
            # Squash the excitation logits into (0, 1) gates, as in the standard
            # SE block; without this the "gates" are unbounded linear outputs.
            gates = self.activation(self.fc(squeezed))
            return x * gates.view(batch_size, chs, 1, 1)
    class UAFM(nn.Module):
        """Attention-weighted fusion of two feature maps of the same shape.

        A spatial weight map ``a`` is predicted from the channel-wise mean and
        max of both inputs, and the output is ``x1 * a + x2 * (1 - a)``.
        """

        def __init__(self):
            super().__init__()
            # 4 descriptor maps in (mean/max of x1, mean/max of x2) ->
            # 1 weight map out, squashed by a sigmoid.
            self.attention = nn.Sequential(
                nn.Conv2d(4, 8, 3, 1, 1),
                nn.LeakyReLU(),
                nn.Conv2d(8, 1, 1, 1),
                nn.Sigmoid(),
            )

        @staticmethod
        def _channel_stats(feat):
            """Return ``(mean, max)`` over the channel dim of *feat*, each keeping a size-1 channel axis."""
            return (
                feat.mean(dim=1, keepdim=True),
                feat.max(dim=1, keepdim=True).values,
            )

        def forward(self, x1, x2):
            """Blend *x1* and *x2* with a learned per-pixel weight map."""
            descriptors = torch.cat(
                self._channel_stats(x1) + self._channel_stats(x2), dim=1
            )
            weight = self.attention(descriptors)
            return x1 * weight + x2 * (1 - weight)
    
    class Net(nn.Module):
        """UNet-style restoration network: ResNet-18 encoder + attention decoder.

        The decoder applies SE attention to the deepest encoder map. It then
        runs three stages of (dense block -> 2x transposed conv -> UAFM fusion
        with the matching encoder map), a final 2x transposed conv, a dense
        block, and a 1x1 conv + sigmoid head producing 3 channels.
        """

        def __init__(self):
            super().__init__()
            # Module creation order is kept fixed: it determines seeded
            # initialisation and the printed architecture.
            self.resnet18 = Resnet18()
            self.SENet = SEnet(chs=256)
            # A single UAFM instance is shared by all three skip fusions.
            self.UAFM = UAFM()

            # Each ConvTranspose2d(k=3, s=2, p=1, output_padding=1) doubles H and W.
            self.DenseNet1 = DenseNetBlock(inplanes=256, planes=256)
            self.transConv1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, output_padding=1)

            self.DenseNet2 = DenseNetBlock(inplanes=128, planes=128)
            self.transConv2 = nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding=1)

            self.DenseNet3 = DenseNetBlock(inplanes=64, planes=64)
            self.transConv3 = nn.ConvTranspose2d(64, 64, 3, 2, 1, output_padding=1)

            self.transConv4 = nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1)
            self.DenseNet4 = DenseNetBlock(inplanes=32, planes=32)
            self.out = nn.Sequential(
                nn.Conv2d(32, 3, 1, 1),
                nn.Sigmoid(),
            )

        def forward(self, x):
            """Restore an image batch and return a 3-channel sigmoid output.

            The encoder taps sit at 1/2, 1/4, 1/8 and 1/16 of the input size.
            H and W presumably need to be multiples of 16 so the upsampled
            decoder maps line up with them — TODO confirm.
            """
            # Downsampling path (encoder).
            feat320, feat160, feat80, feat40 = self.resnet18(x)

            # Upsampling path (decoder) with attention-fused skip connections.
            y = self.SENet(feat40)
            stages = (
                (self.DenseNet1, self.transConv1, feat80),
                (self.DenseNet2, self.transConv2, feat160),
                (self.DenseNet3, self.transConv3, feat320),
            )
            for dense, upsample, skip in stages:
                y = self.UAFM(upsample(dense(y)), skip)

            y = self.DenseNet4(self.transConv4(y))
            return self.out(y)

        def freeze_backbone(self):
            """Stop gradient updates to every ResNet-18 encoder parameter."""
            self.resnet18.requires_grad_(False)

        def unfreeze_backbone(self):
            """Re-enable gradient updates to every ResNet-18 encoder parameter."""
            self.resnet18.requires_grad_(True)
    
    
    if __name__ == '__main__':
        # Smoke test: print the architecture and a layer summary, then the
        # output shape for a dummy batch of six 3x512x512 images.
        model = Net()
        print(model)

        summary(model, input_size=(3, 512, 512), device='cpu')

        dummy_batch = torch.ones((6, 3, 512, 512))
        prediction = model(dummy_batch)
        print(prediction.shape)
    
    
    
    
    
    
    

    主界面显示及代码：
    [图：界面截图]

    from PyQt5.QtGui import *
    from PyQt5.QtWidgets import *
    from untitled import Ui_Form
    import sys
    import cv2 as cv
    from PyQt5.QtCore import QCoreApplication
    import numpy as np
    from PyQt5 import QtCore,QtGui
    from PIL import Image
    from predict import *
    
    class My(QMainWindow, Ui_Form):
        """Main window of the rain/haze removal demo.

        ``pushButton`` picks an image and previews it in ``label``.
        ``pushButton_2`` / ``pushButton_3`` run the global ``pre`` function in
        mode 0 / mode 1 and show the result in ``label_2``.  ``pre`` presumably
        comes from the ``predict`` module; what each mode means (e.g. derain vs
        dehaze) should be confirmed there.
        """

        def __init__(self):
            super(My, self).__init__()
            # No image until the user picks one; the predict handlers check this.
            self.img = None
            self.setupUi(self)
            self.setWindowTitle('图像去雨去雾')
            self.setIcon()
            self.pushButton.clicked.connect(self.pic)
            self.pushButton_2.clicked.connect(self.pre)
            self.pushButton_3.clicked.connect(self.pre2)

        def setIcon(self):
            """Use ``back.png`` as the window background brush (despite the method name)."""
            palette1 = QPalette()
            # palette1.setColor(self.backgroundRole(), QColor(192,253,123))   # set a background colour instead
            palette1.setBrush(self.backgroundRole(), QBrush(QPixmap('back.png')))  # background image
            self.setPalette(palette1)

        def _predict_and_show(self, mode):
            """Run ``pre`` on the loaded image in *mode* and display the result in ``label_2``.

            Shows a warning instead of raising when no image has been loaded yet.
            """
            if self.img is None:
                QMessageBox.warning(self, '提示', '请先选择图片')
                return
            # `pre` here is the module-level function, not the method of the same name.
            out = pre(self.img, mode)
            qimg = self.cv_qt(out)
            # The result is scaled to the size of `label` (the input pane).
            self.label_2.setPixmap(QPixmap.fromImage(qimg).scaled(
                self.label.width(), self.label.height(), QtCore.Qt.KeepAspectRatio))

        def pre(self):
            """Handler for ``pushButton_2``: predict with mode 0."""
            self._predict_and_show(0)

        def pre2(self):
            """Handler for ``pushButton_3``: predict with mode 1."""
            self._predict_and_show(1)

        def pic(self):
            """Ask for an image file, store its pixels in ``self.img`` and preview it in ``label``.

            Cancelling the dialog or choosing an unreadable file leaves the
            current image untouched instead of crashing.
            """
            imgName, _ = QFileDialog.getOpenFileName(self,
                                                     "打开图片",
                                                     "",
                                                     " *.png;;*.jpg;;*.jpeg;;*.bmp;;All Files (*)")
            if not imgName:
                # Dialog cancelled: the original code crashed on Image.open('').
                return
            try:
                # The context manager closes the file handle once pixels are read.
                with Image.open(imgName) as im:
                    img = np.array(im)
            except OSError:
                # Also covers PIL's UnidentifiedImageError (e.g. a non-image
                # picked via "All Files").
                QMessageBox.warning(self, '提示', '无法打开该图片')
                return
            self.img = img

            # KeepAspectRatio: fit the preview to the label size from the designer.
            png = QtGui.QPixmap(imgName).scaled(self.label.width(), self.label.height(), QtCore.Qt.KeepAspectRatio)
            self.label.setPixmap(png)

        def cv_qt(self, src):
            """Convert an ``(H, W)`` or ``(H, W, 3)`` array to a ``QImage``.

            src must be a BGR image: its bytes are read as RGB888 and then
            red/blue-swapped for display.  Grayscale input is replicated into
            three channels.  Assumes 8-bit (uint8) pixel data, since the row
            stride is computed as ``channels * width`` — TODO confirm that
            ``pre`` returns uint8.
            """
            if src.ndim == 2:
                src = np.tile(src[..., np.newaxis], (1, 1, 3))
            # QImage reads raw memory with a fixed row stride, so the buffer must
            # be C-contiguous; sliced or transposed arrays from callers may not be.
            src = np.ascontiguousarray(src)
            h, w, d = src.shape
            bytesperline = d * w
            return QImage(src.data, w, h, bytesperline, QImage.Format_RGB888).rgbSwapped()
    
    if __name__ == '__main__':
        # High-DPI scaling has to be enabled before the QApplication exists.
        QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
        app = QApplication(sys.argv)
        window = My()
        window.show()
        exit_code = app.exec_()
        sys.exit(exit_code)
    
    

    项目结构：
    [图：项目结构截图]
    直接运行 main.py 即可弹出交互界面。
    项目下载地址：见原文“下载地址”列表第 19 项。

  • 相关阅读:
    为什么vue3要选用proxy,好处是什么?
    开箱报告，Simulink Toolbox 库模块使用指南（七）——S-Function Builder 模块
    jmeter实战
    基于微信小程序的茶叶在线商城系统(后台Java+Spring boot+VUE+MySQL)
    JS的数组与字符串方法脑图总结
    MyBatis学习:动态SQL中<if>标签的使用
    45 讲 位图：如何实现网页爬虫中的 URL 去重功能
    05设计模式-建造型模式-建造者模式
    Python 反爬虫策略应对
    详解自动化测试之 Selenium
  • 原文地址:https://blog.csdn.net/qq_45087786/article/details/136684411