• 【自学记录】【Pytorch2.0深度学习从零开始学 王晓华】第二章 深度学习环境搭建


    环境搭建

    参考的这篇帖子点我

    2.3 基于pytorch2.0的图像去噪

    疑问:
    1、莫非是我输出图像错了,总感觉这一章使用的训练集,训练的图像没有噪点。。。
    2、归一化处理测试样本,应该除以255吧?文心也说应该除以255,不知道源码里的512有什么含义。

    x_train = np.reshape(x_train_batch, [-1, 1, 28, 28])  #修正数据输入维度:([30596, 28, 28])
    #  归一化处理测试样本?????????
    x_train /= 512.
    
    • 1
    • 2
    • 3

    解决的问题:
    下面这行代码 ,"…“表示父目录,“…/ ” 表示返回上一级目录,【dataset】文件夹跟【第二章】文件夹并列,当我们在【第二章】文件夹下打开IDE,执行train.py 文件时,”…"便等同于“源码\第二章”,再执行以下语句,便可以找到【dataset】文件夹。

    x_train = np.load("../dataset/mnist/x_train.npy")
    
    • 1

    【dataset】文件夹跟【第二章】文件夹并列
    在【第二章】文件夹下打开IDE

    源码\第二章\train.py

    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
    import torch
    import numpy as np
    import unet
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    batch_size = 320                        #设定每次训练的批次数
    epochs = 1024                           #设定训练次数
    
    #device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
    device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = unet.Unet()                     #导入Unet模型
    model = model.to(device)                #将计算模型传入GPU硬件等待计算
    #model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数/器【Adam】 梯度下降。。。的
    
    #载入数据
    x_train = np.load("../dataset/mnist/x_train.npy")
    y_train_label = np.load("../dataset/mnist/y_train_label.npy")
    
    x_train_batch = x_train
    x_train2=x_train
    
    
    # x_train_batch = []
    # for i in range(len(y_train_label)):
    #     if y_train_label[i] <= 10:                    #为了加速演示作者只对数据集中的小于2的数字,也就是01进行运行,读者可以自行增加训练个数
    #         x_train_batch.append(x_train[i])
    
    x_train = np.reshape(x_train_batch, [-1, 1, 28, 28])  #修正数据输入维度:([30596, 28, 28])
    #  归一化处理测试样本?????????
    x_train /= 512.
    train_length = len(x_train) * 20                       #增加数据的单次循环次数
    
    state_dict = torch.load("./saver/unet.pth")
    model.load_state_dict(state_dict)
    for epoch in range(epochs):
        train_num = train_length // batch_size             #计算有多少批次数
    
        train_loss = 0                                     #用于损失函数的统计
        optimizer.zero_grad()                               #对导数进行清零!!!!!!!!!!!
        for i in tqdm(range(train_num)):                    #开始循环训练
            x_imgs_batch = []                               #创建数据的临时存储位置
            x_step_batch = []
            y_batch = []
            # 对每个批次内的数据进行处理
            for b in range(batch_size):
                img = x_train[np.random.randint(x_train.shape[0])]  #提取单个图片内容
                x = img
                y = img
    
                x_imgs_batch.append(x)
                y_batch.append(y)
    
            #将批次数据转化为Pytorch对应的tensor格式并将其传入GPU中
            x_imgs_batch = torch.tensor(x_imgs_batch).float().to(device)
            y_batch = torch.tensor(y_batch).float().to(device)
    
    
            pred = model(x_imgs_batch)                      #对模型进行正向计算
            loss = torch.nn.MSELoss(reduction="sum")(pred, y_batch)*100.   #使用损失函数进行计算
    
            #这里读者记住下面就是固定格式,一般而言这样使用即可
           ###########################################3
            loss.backward()                                                     #损失值的反向传播
            optimizer.step()                                                    #对参数进行更新
    
            train_loss += loss.item()                                           #记录每个批次的损失值
        #计算并打印损失值
        train_loss /= train_num
        print("train_loss:", train_loss)                                                                                                                                                                                                        
        if epoch%6 == 0:
            torch.save(model.state_dict(),"./saver/unet.pth")#要么存这里,要么存内存里了,类里面了
    
        #下面是对数据进行打印
        ran_img=np.random.randint(x_train.shape[0])
        image = x_train[ran_img]                    #随机挑选一条数据进行计算
        
        plt.rcParams['font.sans-serif']=['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        plt.subplot(121)
        plt.title('图像原始结果')
        plt.imshow(x_train2[ran_img])
       
        image = np.reshape(image,[1,1,28,28])                                   #修正数据维度
    
        image = torch.tensor(image).float().to(device)                          #挑选的数据传入硬件中等待计算
        image = model(image)                                                    #使用模型对数据进行计算
    
        image = torch.reshape(image, shape=[28,28])                             #修正模型输出结果
        image = image.detach().cpu().numpy()                                    #将计算结果导入CPU中进行后续计算或者展示
     
        #展示或计算数据结果
        plt.subplot(122)
        plt.rcParams['font.sans-serif']=['SimHei']
        plt.rcParams['axes.unicode_minus'] = False
        plt.title('消除噪声后的结果')
        plt.imshow(image)
        plt.savefig(f"./img/img_{epoch}.jpg")
        plt.show()
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105

    结果:只训练了几次,左边是直接输出的训练集原始图像,这样输出的,不知道对不对
    请添加图片描述请添加图片描述

    请添加图片描述
    请添加图片描述

  • 相关阅读:
    好用的递归子查询
    《Linux运维实战:创建LVM挂载到指定目录》
    day012--mysql中的聚合函数
    【动态规划】392. 判断子序列、115. 不同的子序列
    TCP零基础详解
    [附源码]Python计算机毕业设计Django家庭整理服务管理系统
    Harmony Next 文件命令操作(发送、读取、媒体文件查询)
    【入门级小游戏】C语言数组函数:解析三(N)子棋
    从Mysql架构看一条查询sql的执行过程
    MYSQL | 数据库到底是怎么来的?
  • 原文地址:https://blog.csdn.net/weixin_43502713/article/details/137243493