• [PyTorch][chapter 53][Auto Encoder 实战]


    前言:

         结合手写数字识别的例子,实现以下AutoEncoder

         ae.py:  实现autoEncoder 网络

         main.py: 加载手写数字数据集,以及训练,验证,测试网络。

    左图:原图像

    右图:重构图像

     ----main-----

     每轮训练时间 : 91
    0 loss: 0.02758789248764515

     每轮训练时间 : 95
    1 loss: 0.024654878303408623

     每轮训练时间 : 149
    2 loss: 0.018874473869800568

    目录:

         1: AE 实现

         2: main 实现


    一  ae(AutoEncoder) 实现

      文件名: ae.py

                   模型的搭建

       注意点:

                手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

    它的标签就是输入的x, 需要重构本身

    自编码器(autoencoder, AE)是一类在半监督学习非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  

    编码器包含编码器(encoder)和解码器decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

    自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

       

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:19:19 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch import nn
    8. #ae: AutoEncoder
    9. class AE(nn.Module):
    10. def __init__(self,hidden_size=10):
    11. super(AE, self).__init__()
    12. self.encoder = nn.Sequential(
    13. nn.Linear(in_features=784, out_features=256),
    14. nn.ReLU(),
    15. nn.Linear(in_features=256, out_features=128),
    16. nn.ReLU(),
    17. nn.Linear(in_features=128, out_features=64),
    18. nn.ReLU(),
    19. nn.Linear(in_features=64, out_features=hidden_size),
    20. nn.ReLU()
    21. )
    22. # hidden [batch_size, 10]
    23. self.decoder = nn.Sequential(
    24. nn.Linear(in_features=hidden_size, out_features=64),
    25. nn.ReLU(),
    26. nn.Linear(in_features=64, out_features=128),
    27. nn.ReLU(),
    28. nn.Linear(in_features=128, out_features=256),
    29. nn.ReLU(),
    30. nn.Linear(in_features=256, out_features=784),
    31. nn.Sigmoid()
    32. )
    33. def forward(self, x):
    34. '''
    35. param x:[batch, 1,28,28]
    36. return
    37. '''
    38. m= x.size(0)
    39. x = x.view(m, 784)
    40. hidden= self.encoder(x)
    41. x = self.decoder(hidden)
    42. #reshape
    43. x = x.view(m,1,28,28)
    44. return x

    二 main 实现

      文件名: main.py

      作用:

          加载数据集

         训练模型

         测试模型泛化能力

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:24:10 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch.utils.data import DataLoader
    8. from torchvision import transforms, datasets
    9. import time
    10. from torch import optim,nn
    11. from ae import AE
    12. import visdom
    13. def main():
    14. batchNum = 32
    15. lr = 1e-3
    16. epochs = 20
    17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    18. torch.manual_seed(1234)
    19. viz = visdom.Visdom()
    20. viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))
    21. tf= transforms.Compose([ transforms.ToTensor()])
    22. mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
    23. train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
    24. mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
    25. test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
    26. global_step =0
    27. model =AE().to(device)
    28. criteon = nn.MSELoss().to(device) #损失函数
    29. optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
    30. print("\n ----main-----")
    31. for epoch in range(epochs):
    32. start = time.perf_counter()
    33. for step ,(x,y) in enumerate(train_data):
    34. #[b,1,28,28]
    35. x = x.to(device)
    36. x_hat = model(x)
    37. loss = criteon(x_hat, x)
    38. #backprop
    39. optimizer.zero_grad()
    40. loss.backward()
    41. optimizer.step()
    42. viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
    43. global_step +=1
    44. end = time.perf_counter()
    45. interval = end - start
    46. print("\n 每轮训练时间 :",int(interval))
    47. print(epoch, 'loss:',loss.item())
    48. x,target = iter(test_data).next()
    49. x = x.to(device)
    50. with torch.no_grad():
    51. x_hat = model(x)
    52. tip = 'hat'+str(epoch)
    53. viz.images(x,nrow=8, win='x',opts=dict(title='x'))
    54. viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
    55. if __name__ == '__main__':
    56. main()

  • 相关阅读:
    WordPress主题开发(四)之—— 模板文件
    LeetCode707:设计链表
    2023年05月 Python(五级)真题解析#中国电子学会#全国青少年软件编程等级考试
    精致小巧,支持苹果 Find My 的Chipolo ONE Spot
    python3-常用数据结构
    java计算机毕业设计销售人员绩效管理系统源码+系统+数据库+lw文档(1)
    Arduino开发板使用I2C SSD1306 OLED显示屏的方法
    警惕!外贸常见的一些骗局!
    opencv 形态学转换
    沈阳陪诊系统|沈阳陪诊系统开发|沈阳陪诊系统功能和优势
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/132583350