• 图像超分辨率模型:Real-ESRGAN | 论文阅读+实战记录


    前言

    最近需要一个超分的模型,经过调研准备用 Real-ESRGAN。特此记录论文阅读和实战过程。

    论文阅读

    论文地址:Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data
    Github:https://github.com/xinntao/Real-ESRGAN
    参考视频:https://www.bilibili.com/video/BV14541117y6

    主要贡献:

    • 提出一种高阶退化过程(high-order degradation process)来模拟实际退化,并利用 sinc 滤波器给训练图片添加 Ringing artifacts(振铃伪影,周围震荡波的感觉)和 Overshoot artifacts(过冲伪影,如白边)构造训练集
    • 用 U-net 而不是 VGG 作为 GAN 的 Discriminator,提高鉴别器能力并稳定训练动态
    • Real-ESRGAN 性能更优,效果更好

    效果对比:
    在这里插入图片描述

    数据集构建:
    在这里插入图片描述
    二阶退化。两阶段分别经过虚化、下采样、加噪声、JPEG压缩。其中 sinc filter 用来添加伪影。

    伪影类型实例
    振铃伪影在这里插入图片描述
    过冲伪影在这里插入图片描述

    Real-ESRGAN 模型结构:

    • Generator:结构和 ESRGAN 一样,不过他做了个 Pixel Unshuffle,减少图像长宽,增加 channel size。送到一个 RRDB 的残差网络最后上采样得到输出。
      在这里插入图片描述
    • Discriminator:用的 U-Net(下采样和上采样的feature之间有连接,能学习局部纹理信息),不像最初的 GAN 判别器输出 0/1(即全局是否合格),他这个输出是和原图size一致,其中每个像素点的值衡量真实性的 feedback(即各局部是否合格)。此外,用了 spectral norm(谱归一化),能提升训练稳定性,减少伪影。
      在这里插入图片描述

    两阶段模型训练:
    首先,用 L1 loss 训练一个小网络(PSNR-oriented model)Real-ESRNet;然后,用其初始化 Generator 的权重,并用 L1 loss、perceptual loss、GAN loss 的结合训练最终的模型。

    此外,还对 ground-truth 加了锐化再训了个 Real-ESRGAN+,能提升图像生成的锐度,但不会增加伪影。

    实战记录

    git clone https://github.com/xinntao/Real-ESRGAN.git
    cd Real-ESRGAN
    # Install basicsr - https://github.com/xinntao/BasicSR
    # We use BasicSR for both training and inference
    pip install basicsr
    # facexlib and gfpgan are for face enhancement
    pip install facexlib
    pip install gfpgan
    pip install -r requirements.txt
    python setup.py develop
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    需要用到 realesrgan 模块的 RealESRGANer 类。以4倍超分为例:

    import os
    import cv2
    import torch
    import numpy as np
    from PIL import Image
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    
    ckpt_path = "./checkpoints/real-esrgan"
    model_path = os.path.join(ckpt_path, "RealESRGAN_x4plus.pth")
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    netscale = 4
    
    upsampler = RealESRGANer(
    	scale=netscale,
    	model_path=self.model_path,
    	dni_weight=self.dni_weight,
    	model=self.model,
    	tile=0,                     # Tile size, 0 for no tile during testing
    	tile_pad=10,                # Tile padding
    	pre_pad=0,                  # Pre padding size at each border
    	half=not fp16,
    	device=device)
    
    def enhance(image, width, height):
        """
            image: PIL Image Obj
            输出: PIL Image Obj
        """
        try:
            image_cv2 = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            output, _ = upsampler.enhance(image_cv2, outscale=self.outscale)  # _ is RGB/RGBA
            image_pil = Image.fromarray(cv2.cvtColor(output, cv2.COLOR_BGR2RGB)).resize((width, height)).convert('RGB')
            # print(output, _, image_pil)
            return image_pil
        except Exception as e:
            print("enhance Exception: ", e)
        finally:
            torch.cuda.empty_cache()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39

    测试(超分前):
    在这里插入图片描述

    超分后:
    在这里插入图片描述

  • 相关阅读:
    三、nginx设置浏览器缓存[expires]
    01-windows栈溢出
    Kafka topic分区增加副本
    Vue子组件向父组件数据传递,非父子组件数据传递,单向数据流&组件双向数据共享
    RabbitMQ从0到1完整学习笔记二:《高级篇》
    IM即时通讯开发如何实现Android版智能心跳机制
    数据库测试技术点
    描述Java中的备忘录模式。
    PyTorch笔记 - Attention Is All You Need (2)
    实验32:气压温度传感器实验
  • 原文地址:https://blog.csdn.net/muyao987/article/details/127960309