• [PyTorch][chapter 53][Auto Encoder 实战]


    前言:

         结合手写数字识别的例子,实现以下AutoEncoder

         ae.py:  实现autoEncoder 网络

         main.py: 加载手写数字数据集,以及训练,验证,测试网络。

    左图:原图像

    右图:重构图像

     ----main-----

     每轮训练时间 : 91
    0 loss: 0.02758789248764515

     每轮训练时间 : 95
    1 loss: 0.024654878303408623

     每轮训练时间 : 149
    2 loss: 0.018874473869800568

    目录:

         1: AE 实现

         2: main 实现


    一  ae(AutoEncoder) 实现

      文件名: ae.py

                   模型的搭建

       注意点:

                手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

    它的标签就是输入的x, 需要重构本身

    自编码器(autoencoder, AE)是一类在半监督学习非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  

    编码器包含编码器(encoder)和解码器decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

    自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

       

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:19:19 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch import nn
    8. #ae: AutoEncoder
    9. class AE(nn.Module):
    10. def __init__(self,hidden_size=10):
    11. super(AE, self).__init__()
    12. self.encoder = nn.Sequential(
    13. nn.Linear(in_features=784, out_features=256),
    14. nn.ReLU(),
    15. nn.Linear(in_features=256, out_features=128),
    16. nn.ReLU(),
    17. nn.Linear(in_features=128, out_features=64),
    18. nn.ReLU(),
    19. nn.Linear(in_features=64, out_features=hidden_size),
    20. nn.ReLU()
    21. )
    22. # hidden [batch_size, 10]
    23. self.decoder = nn.Sequential(
    24. nn.Linear(in_features=hidden_size, out_features=64),
    25. nn.ReLU(),
    26. nn.Linear(in_features=64, out_features=128),
    27. nn.ReLU(),
    28. nn.Linear(in_features=128, out_features=256),
    29. nn.ReLU(),
    30. nn.Linear(in_features=256, out_features=784),
    31. nn.Sigmoid()
    32. )
    33. def forward(self, x):
    34. '''
    35. param x:[batch, 1,28,28]
    36. return
    37. '''
    38. m= x.size(0)
    39. x = x.view(m, 784)
    40. hidden= self.encoder(x)
    41. x = self.decoder(hidden)
    42. #reshape
    43. x = x.view(m,1,28,28)
    44. return x

    二 main 实现

      文件名: main.py

      作用:

          加载数据集

         训练模型

         测试模型泛化能力

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:24:10 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch.utils.data import DataLoader
    8. from torchvision import transforms, datasets
    9. import time
    10. from torch import optim,nn
    11. from ae import AE
    12. import visdom
    13. def main():
    14. batchNum = 32
    15. lr = 1e-3
    16. epochs = 20
    17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    18. torch.manual_seed(1234)
    19. viz = visdom.Visdom()
    20. viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))
    21. tf= transforms.Compose([ transforms.ToTensor()])
    22. mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
    23. train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
    24. mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
    25. test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
    26. global_step =0
    27. model =AE().to(device)
    28. criteon = nn.MSELoss().to(device) #损失函数
    29. optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
    30. print("\n ----main-----")
    31. for epoch in range(epochs):
    32. start = time.perf_counter()
    33. for step ,(x,y) in enumerate(train_data):
    34. #[b,1,28,28]
    35. x = x.to(device)
    36. x_hat = model(x)
    37. loss = criteon(x_hat, x)
    38. #backprop
    39. optimizer.zero_grad()
    40. loss.backward()
    41. optimizer.step()
    42. viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
    43. global_step +=1
    44. end = time.perf_counter()
    45. interval = end - start
    46. print("\n 每轮训练时间 :",int(interval))
    47. print(epoch, 'loss:',loss.item())
    48. x,target = iter(test_data).next()
    49. x = x.to(device)
    50. with torch.no_grad():
    51. x_hat = model(x)
    52. tip = 'hat'+str(epoch)
    53. viz.images(x,nrow=8, win='x',opts=dict(title='x'))
    54. viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
    55. if __name__ == '__main__':
    56. main()

  • 相关阅读:
    交通地理信息系统实习教程(二)
    多目标优化算法:基于非支配排序的霸王龙优化算法(NSTROA)MATLAB
    iOS应用程序数据保护:如何保护iOS应用程序中的图片、资源和敏感数据
    发版检查list
    06.JAVAEE之线程4
    应用现代化产业联盟,正式成立
    水厂消毒的设施设备有哪些
    隧道代理vs普通代理:哪种更适合您的爬虫应用?
    C++ 15:虚表,虚函数,多态,指针
    如何在 Endless OS 上安装 ONLYOFFICE 桌面编辑器 7.5
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/132583350