• 超分辨率重建——SESR网络训练并推理测试(详细图文教程)


    最近学了一个超轻量化的超分辨率重建网络SESR,效果还不错。
    在这里插入图片描述

    一、 源码包

    SESR官网的地址为:官网

    我自己调整过的源码包获取方法文章末扫码到公众号中回复关键字:超分辨率重建SESR。获取下载链接。

    论文地址:论文

    源码包推荐使用我给的,我注释过很多地方,看起来不吃力,且我自己添加了推理测试脚本。

    下载好源码包解压后的样子如下:

    在这里插入图片描述

    二、 数据集的准备

    获取数据集可以有两种方法:

    2.1 官网下载

    直接运行源码包中的脚本文件train.py,会自动先下载div2k数据集,但是下载的非常慢,高分辨率数据集有3G多,容易下蹦了。默认会下载到系统C盘下,具体路径为:

    C:\Users\Administrator\tensorflow_datasets\downloads,每次下载失败后再次运行又会重新生成序列码并重新下载,很麻烦。

    如下:
    在这里插入图片描述

    DIV2K数据集下载链接:下载DIV2K

    四个测试集(Set5,Set4,B100,Urban100)下载链接:测试集下载

    2.2 网盘下载

    我提供了一个我下载好并整理好的数据集,文件存放对应关系我都整理好了,学者可以直接下载导入使用,下载链接为:网盘下载 ,提取码为:32d4

    三、 训练环境配置

    该网络结构是在TensorFlow框架下运行的TensorFlow版本是2.3,还有一个包的版本是tensorflow_datasets==4.1,Pyhton3.6版本,额。。。。。。。。。。。。。。。。。。

    踩了很多坑,最后我自己调通的版本是TensorFlow-gpu2.9,Python 3.7版本,tensorflow_datasets4.8.2,如下:

    在这里插入图片描述

    安装好TensorFlow-GPU后先测试一下能不能正常调用GPU,测试方法参考:添加链接描述

    四、训练

    4.1 修改配置参数

    打开train.py文件,里面有些配置参数根据自己电脑情况修改:

    在这里插入图片描述

    train.py脚本中对应上图修改的地方如下:

    在这里插入图片描述

    4.2 导入数据集

    下载好我提供的数据集后,解压好讲整个tensorflow_datasets文件夹放到data文件夹中,并将tensorflow_datasets文件夹所在路径赋值给变量data_dir,代码中具体的修改地方如下:

    在这里插入图片描述

    4.3 2倍超分网络训练

    根据自己需求选择要训练深度:

    4.3.1 训练SESR-M5网络

    其中m = 5,f = 16,feature_size = 256,具有折叠线性块:

    python train.py
    
    • 1

    4.3.2 训练SESR-M5网络

    m = 5,f = 16,feature_size = 256,扩展线性块:

    python train.py --linear_block_type expanded
    
    • 1

    4.3.3 训练SESR-M11网络

    其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

    python train.py --m 11 --feature_size 64
    
    • 1

    4.4.4 训练SESR-XL网络

    其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

    python train.py --m 11 --int_features 32 --feature_size 64
    
    • 1

    4.4 2倍超分网络模型

    通过上面步骤训练好后会在logs文件中自动保存权重文件和模型,我自己训练好的模型权重文件都打包在源码包了,学者可以直接使用,如下:

    在这里插入图片描述
    上面各个文件代表内容为:

    .pb:表示protocol buffers,是模型结构和参数的二进制序列化文件。存储了模型的网络结构,变量,权重等信息。是模型persist的主要文件。

    .data-00000-of-00001:存储了模型变量的取值,即模型权重参数的值。模型训练完成后保存的权重。

    .index:索引文件,存放了参数tensor的meta信息,如tensor名称、维度等。用于定位data文件中的tensor数据。

    checkpoints文件:存储模型训练过程中的参数,用于恢复训练。

    4.5 修改模型保存格式

    上面是默认的保存方式,学长如果需要其他格式的自己修改保存方法,具体修改地方如下:

    在这里插入图片描述

    4.6 4倍超分网络训练

    4倍超分网络得在2倍超分模型基础上训练才行,网络深度自己选择:

    4.6.1 训练SESR-M5网络

    其中m = 5,f = 16,feature_size = 256,具有折叠线性块:

    python train.py --scale 4
    
    • 1

    4.6.2 训练SESR-M5网络

    m = 5,f = 16,feature_size = 256,扩展线性块:

    python train.py --linear_block_type expanded --scale 4
    
    • 1

    4.6.3 训练SESR-M11网络

    其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

    python train.py --m 11 --feature_size 64 --scale 4
    
    • 1

    4.6.4 训练SESR-XL网络

    其中m = 11,f = 16,feature_size = 64,具有折叠线性块:

    python train.py --m 11 --int_features 32 --feature_size 64 --scale 4
    
    • 1

    4.7 4倍超分网络模型

    训练好后,模型会自动保存在logs文件中,如下:

    在这里插入图片描述

    五、量化训练

    运行以下命令,在训练时对网络进行调试,并生成TFLITE(用于x2 SISR、SESR-M5网络):

    python train.py --quant_W --quant_A --gen_tflite
    
    • 1

    5.1 量化训练模型

    训练好后自动保存在logs/x2_models文件下,如下:

    在这里插入图片描述

    六、模型推理测试

    推理脚本是我自己写的,具体使用如下,根据需求自行选择:

    在这里插入图片描述在这里插入图片描述
    在这里插入图片描述

    七、超分效果

    在这里插入图片描述
    在这里插入图片描述

    八、总结

    以上就是超分辨率重建——SESR网络训练并推理测试的详细图文教程,总结不易,给个三连多多支持,谢谢!欢迎留言讨论。

  • 相关阅读:
    手机号码骚扰拦截
    论文笔记:A survey of deep nonnegative matrix factorization
    Android使用Zxing库生成PDF417扫描后多一个字符A
    【Python基础入门3】转义字符和原字符
    set容器
    普林斯顿10分钟剧本创作比赛
    ceph分布式存储
    [自制操作系统] 第12回 实现中断代码
    新手小白看过来——带你快速入门跨境电商
    UniApp 中的路由魔法:玩转页面导航与跳转
  • 原文地址:https://blog.csdn.net/qq_40280673/article/details/134062403