• [PyTorch][chapter 53][Auto Encoder 实战]


    前言:

         结合手写数字识别的例子,实现以下AutoEncoder

         ae.py:  实现autoEncoder 网络

         main.py: 加载手写数字数据集,以及训练,验证,测试网络。

    左图:原图像

    右图:重构图像

     ----main-----

     每轮训练时间 : 91
    0 loss: 0.02758789248764515

     每轮训练时间 : 95
    1 loss: 0.024654878303408623

     每轮训练时间 : 149
    2 loss: 0.018874473869800568

    目录:

         1: AE 实现

         2: main 实现


    一  ae(AutoEncoder) 实现

      文件名: ae.py

                   模型的搭建

       注意点:

                手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

    它的标签就是输入的x, 需要重构本身

    自编码器(autoencoder, AE)是一类在半监督学习非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  

    编码器包含编码器(encoder)和解码器decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

    自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

       

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:19:19 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch import nn
    8. #ae: AutoEncoder
    9. class AE(nn.Module):
    10. def __init__(self,hidden_size=10):
    11. super(AE, self).__init__()
    12. self.encoder = nn.Sequential(
    13. nn.Linear(in_features=784, out_features=256),
    14. nn.ReLU(),
    15. nn.Linear(in_features=256, out_features=128),
    16. nn.ReLU(),
    17. nn.Linear(in_features=128, out_features=64),
    18. nn.ReLU(),
    19. nn.Linear(in_features=64, out_features=hidden_size),
    20. nn.ReLU()
    21. )
    22. # hidden [batch_size, 10]
    23. self.decoder = nn.Sequential(
    24. nn.Linear(in_features=hidden_size, out_features=64),
    25. nn.ReLU(),
    26. nn.Linear(in_features=64, out_features=128),
    27. nn.ReLU(),
    28. nn.Linear(in_features=128, out_features=256),
    29. nn.ReLU(),
    30. nn.Linear(in_features=256, out_features=784),
    31. nn.Sigmoid()
    32. )
    33. def forward(self, x):
    34. '''
    35. param x:[batch, 1,28,28]
    36. return
    37. '''
    38. m= x.size(0)
    39. x = x.view(m, 784)
    40. hidden= self.encoder(x)
    41. x = self.decoder(hidden)
    42. #reshape
    43. x = x.view(m,1,28,28)
    44. return x

    二 main 实现

      文件名: main.py

      作用:

          加载数据集

         训练模型

         测试模型泛化能力

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:24:10 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch.utils.data import DataLoader
    8. from torchvision import transforms, datasets
    9. import time
    10. from torch import optim,nn
    11. from ae import AE
    12. import visdom
    13. def main():
    14. batchNum = 32
    15. lr = 1e-3
    16. epochs = 20
    17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    18. torch.manual_seed(1234)
    19. viz = visdom.Visdom()
    20. viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))
    21. tf= transforms.Compose([ transforms.ToTensor()])
    22. mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
    23. train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
    24. mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
    25. test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
    26. global_step =0
    27. model =AE().to(device)
    28. criteon = nn.MSELoss().to(device) #损失函数
    29. optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
    30. print("\n ----main-----")
    31. for epoch in range(epochs):
    32. start = time.perf_counter()
    33. for step ,(x,y) in enumerate(train_data):
    34. #[b,1,28,28]
    35. x = x.to(device)
    36. x_hat = model(x)
    37. loss = criteon(x_hat, x)
    38. #backprop
    39. optimizer.zero_grad()
    40. loss.backward()
    41. optimizer.step()
    42. viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
    43. global_step +=1
    44. end = time.perf_counter()
    45. interval = end - start
    46. print("\n 每轮训练时间 :",int(interval))
    47. print(epoch, 'loss:',loss.item())
    48. x,target = iter(test_data).next()
    49. x = x.to(device)
    50. with torch.no_grad():
    51. x_hat = model(x)
    52. tip = 'hat'+str(epoch)
    53. viz.images(x,nrow=8, win='x',opts=dict(title='x'))
    54. viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
    55. if __name__ == '__main__':
    56. main()

  • 相关阅读:
    服务器简单介绍
    冷热电气多能互补的微能源网鲁棒优化调度附Matlab代码
    Mybatis实战练习二【查询详情】
    touchGFX综合学习十三、基于cubeMX、正点原子H750开发版、RGB4.3寸屏移植touchGFX完整教程+工程(一)
    Strimzi Kafka Bridge(桥接)实战之二:生产和发送消息
    漫谈:C语言 C++ 所有编程语言 =和==的麻烦
    【Visual Leak Detector】源码文件概览
    大数据_数据中台建设的成熟度评估模型
    可变形卷积 DeformConv2d
    C语言天花板——指针(初阶)
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/132583350