• [PyTorch][chapter 53][Auto Encoder 实战]


    前言:

         结合手写数字识别的例子,实现以下AutoEncoder

         ae.py:  实现autoEncoder 网络

         main.py: 加载手写数字数据集,以及训练,验证,测试网络。

    左图:原图像

    右图:重构图像

     ----main-----

     每轮训练时间 : 91
    0 loss: 0.02758789248764515

     每轮训练时间 : 95
    1 loss: 0.024654878303408623

     每轮训练时间 : 149
    2 loss: 0.018874473869800568

    目录:

         1: AE 实现

         2: main 实现


    一  ae(AutoEncoder) 实现

      文件名: ae.py

                   模型的搭建

       注意点:

                手写数字数据集 提供了 标签y,但是AutoEncoder 网络不需要,

    它的标签就是输入的x, 需要重构本身

    自编码器(autoencoder, AE)是一类在半监督学习非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning) [1-2]  

    编码器包含编码器(encoder)和解码器decoder)两部分 [2]  。按学习范式,自编码器可以被分为收缩自编码器(contractive autoencoder)、正则自编码器(regularized autoencoder)和变分自编码器(Variational AutoEncoder, VAE),其中前两者是判别模型、后者是生成模型 [2]  。按构筑类型,自编码器可以是前馈结构或递归结构的神经网络。

    自编码器具有一般意义上表征学习算法的功能,被应用于降维(dimensionality reduction)和异常值检测(anomaly detection) [2]  。包含卷积层构筑的自编码器可被应用于计算机视觉问题,包括图像降噪(image denoising) [3]  、神经风格迁移(neural style transfer)等 [4]  。

       

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:19:19 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch import nn
    8. #ae: AutoEncoder
    9. class AE(nn.Module):
    10. def __init__(self,hidden_size=10):
    11. super(AE, self).__init__()
    12. self.encoder = nn.Sequential(
    13. nn.Linear(in_features=784, out_features=256),
    14. nn.ReLU(),
    15. nn.Linear(in_features=256, out_features=128),
    16. nn.ReLU(),
    17. nn.Linear(in_features=128, out_features=64),
    18. nn.ReLU(),
    19. nn.Linear(in_features=64, out_features=hidden_size),
    20. nn.ReLU()
    21. )
    22. # hidden [batch_size, 10]
    23. self.decoder = nn.Sequential(
    24. nn.Linear(in_features=hidden_size, out_features=64),
    25. nn.ReLU(),
    26. nn.Linear(in_features=64, out_features=128),
    27. nn.ReLU(),
    28. nn.Linear(in_features=128, out_features=256),
    29. nn.ReLU(),
    30. nn.Linear(in_features=256, out_features=784),
    31. nn.Sigmoid()
    32. )
    33. def forward(self, x):
    34. '''
    35. param x:[batch, 1,28,28]
    36. return
    37. '''
    38. m= x.size(0)
    39. x = x.view(m, 784)
    40. hidden= self.encoder(x)
    41. x = self.decoder(hidden)
    42. #reshape
    43. x = x.view(m,1,28,28)
    44. return x

    二 main 实现

      文件名: main.py

      作用:

          加载数据集

         训练模型

         测试模型泛化能力

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Wed Aug 30 14:24:10 2023
    4. @author: chengxf2
    5. """
    6. import torch
    7. from torch.utils.data import DataLoader
    8. from torchvision import transforms, datasets
    9. import time
    10. from torch import optim,nn
    11. from ae import AE
    12. import visdom
    13. def main():
    14. batchNum = 32
    15. lr = 1e-3
    16. epochs = 20
    17. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    18. torch.manual_seed(1234)
    19. viz = visdom.Visdom()
    20. viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))
    21. tf= transforms.Compose([ transforms.ToTensor()])
    22. mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)
    23. train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)
    24. mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)
    25. test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)
    26. global_step =0
    27. model =AE().to(device)
    28. criteon = nn.MSELoss().to(device) #损失函数
    29. optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则
    30. print("\n ----main-----")
    31. for epoch in range(epochs):
    32. start = time.perf_counter()
    33. for step ,(x,y) in enumerate(train_data):
    34. #[b,1,28,28]
    35. x = x.to(device)
    36. x_hat = model(x)
    37. loss = criteon(x_hat, x)
    38. #backprop
    39. optimizer.zero_grad()
    40. loss.backward()
    41. optimizer.step()
    42. viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
    43. global_step +=1
    44. end = time.perf_counter()
    45. interval = end - start
    46. print("\n 每轮训练时间 :",int(interval))
    47. print(epoch, 'loss:',loss.item())
    48. x,target = iter(test_data).next()
    49. x = x.to(device)
    50. with torch.no_grad():
    51. x_hat = model(x)
    52. tip = 'hat'+str(epoch)
    53. viz.images(x,nrow=8, win='x',opts=dict(title='x'))
    54. viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))
    55. if __name__ == '__main__':
    56. main()

  • 相关阅读:
    设计模式之抽象工厂模式(学习笔记)
    利用JDBC及Servlet实现对信息录入注册功能的实现
    笔试面试相关记录(6)
    2024-04-02 问AI:介绍一下深度学习中的 “迁移学习”
    3.2 C++高级编程_抽象类界面
    【论文解读】RLAIF基于人工智能反馈的强化学习
    Unity技术手册-UGUI零基础详细教程-Toggle切换
    【无题】仙女话术
    炫酷的表白烟花 html+css+js实现的表白烟花特效(程序员专属情人节表白网站)
    如何快速实现一个颜色选择器
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/132583350