• 【PyTorch][chapter 20][李宏毅深度学习]【无监督学习][ GAN]【实战】


    前言

     本篇主要是结合手写数字例子,结合PyTorch 介绍一下Gan 实战

    第一轮训练效果

    第20轮训练效果,已经可以生成数字了

    68 轮


    目录: 

    1.   谷歌云服务器(Google Colab)
    2.   整体训练流程
    3.   Python 代码

    一  谷歌云服务器(Google Colab)

         个人用的一直是联想小新笔记本,虽然非常稳定方便。但是现在跑深度学习,性能确实有点跟不上. 

       1.1    打开谷歌云服务器(Google Colab)

          https://colab.research.google.com/

        1. 2  新建笔记

                     

    1

     1.4  选择T4GPU 

    1.5  点击运行按钮

    可以看到当前硬件的情况

         


    二  整体训练流程


    三    PyTorch 例子

    1. # -*- coding: utf-8 -*-
    2. """
    3. Created on Fri Mar 1 13:27:49 2024
    4. @author: chengxf2
    5. """
    6. import torch.optim as optim #优化器
    7. import numpy as np
    8. import matplotlib.pyplot as plt
    9. import torchvision
    10. from torchvision import transforms
    11. import torch
    12. import torch.nn as nn
    13. #第一步加载手写数字集
    14. def loadData():
    15. #同时归一化数据集(-1,1)
    16. style = transforms.Compose([
    17. transforms.ToTensor(), #0-1 归一化0-1, channel,height,width
    18. transforms.Normalize(mean=0.5, std=0.5) #变成了-1,1
    19. ]
    20. )
    21. trainData = torchvision.datasets.MNIST('data',
    22. train=True,
    23. transform=style,
    24. download=True)
    25. dataloader = torch.utils.data.DataLoader(trainData,
    26. batch_size= 16,
    27. shuffle=True)
    28. imgs,_ = next(iter(dataloader))
    29. #torch.Size([64, 1, 28, 28])
    30. print("\n imgs shape ",imgs.shape)
    31. return dataloader
    32. class Generator(nn.Module):
    33. '''
    34. 定义生成器
    35. 输入:
    36. z 随机噪声[batch, input_size]
    37. 输出:
    38. x: 图片 [batch, height, width, channel]
    39. '''
    40. def __init__(self,input_size):
    41. super(Generator,self).__init__()
    42. self.net = nn.Sequential(
    43. nn.Linear(in_features = input_size , out_features =256),
    44. nn.ReLU(),
    45. nn.Linear(in_features = 256 , out_features =512),
    46. nn.ReLU(),
    47. nn.Linear(in_features = 512 , out_features =28*28),
    48. nn.Tanh()
    49. )
    50. def forward(self, z):
    51. # z 随机输入[batch, dim]
    52. x = self.net(z)
    53. #[batch, height, width, channel]
    54. #print(x.shape)
    55. x = x.view(-1,28,28,1)
    56. return x
    57. class Discriminator(nn.Module):
    58. '''
    59. 定义鉴别器
    60. 输入:
    61. x: 图片 [batch, height, width, channel]
    62. 输出:
    63. y: 二分类图片的概率: BCELoss 计算交叉熵损失
    64. '''
    65. def __init__(self):
    66. super(Discriminator,self).__init__()
    67. #开始的维度和终止的维度,默认值分别是1和-1
    68. self.flatten = nn.Flatten()
    69. self.net = nn.Sequential(
    70. nn.Linear(in_features = 28*28 , out_features =512),
    71. nn.LeakyReLU(), #负值的时候保留梯度信息
    72. nn.Linear(in_features = 512 , out_features =256),
    73. nn.LeakyReLU(),
    74. nn.Linear(in_features = 256 , out_features =1),
    75. nn.Sigmoid()
    76. )
    77. def forward(self, x):
    78. x = self.flatten(x)
    79. #print(x.shape)
    80. out =self.net(x)
    81. return out
    82. def gen_img_plot(model, epoch, test_input):
    83. out = model(test_input).detach().cpu()
    84. out = out.numpy()
    85. imgs = np.squeeze(out)
    86. fig = plt.figure(figsize=(4,4))
    87. for i in range(out.shape[0]):
    88. plt.subplot(4,4,i+1)
    89. img = (imgs[i]+1)/2.0#[-1,1]
    90. plt.imshow(img)
    91. plt.axis('off')
    92. plt.show()
    93. def train():
    94. #1 初始化参数
    95. device ='cuda' if torch.cuda.is_available() else 'cpu'
    96. #2 加载训练数据
    97. dataloader = loadData()
    98. test_input = torch.randn(16,100,device=device)
    99. #3 超参数
    100. maxIter = 20 #最大训练次数
    101. input_size = 100
    102. batchNum = 16
    103. input_size =100
    104. #4 初始化模型
    105. gen = Generator(100).to(device)
    106. dis = Discriminator().to(device)
    107. #5 优化器,损失函数
    108. d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
    109. g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)
    110. loss_fn = torch.nn.BCELoss()
    111. #6 loss 变化列表
    112. D_loss =[]
    113. G_loss= []
    114. for epoch in range(0,maxIter):
    115. d_epoch_loss = 0.0
    116. g_epoch_loss =0.0
    117. #count = len(dataloader)
    118. for step ,(realImgs, _) in enumerate(dataloader):
    119. realImgs = realImgs.to(device)
    120. random_noise = torch.randn(batchNum, input_size).to(device)
    121. #先训练判别器
    122. d_optim.zero_grad()
    123. real_output = dis(realImgs)
    124. d_real_loss = loss_fn(real_output, torch.ones_like(real_output))
    125. d_real_loss.backward()
    126. #不要训练生成器,所以要生成器detach
    127. fake_img = gen(random_noise)
    128. fake_output = dis(fake_img.detach())
    129. d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
    130. d_fake_loss.backward()
    131. d_loss = d_real_loss+d_fake_loss
    132. d_optim.step()
    133. #优化生成器
    134. g_optim.zero_grad()
    135. fake_output = dis(fake_img.detach())
    136. g_loss = loss_fn(fake_output, torch.ones_like(fake_output))
    137. g_loss.backward()
    138. g_optim.step()
    139. with torch.no_grad():
    140. d_epoch_loss+= d_loss
    141. g_epoch_loss+= g_loss
    142. count = 16
    143. with torch.no_grad():
    144. d_epoch_loss/=count
    145. g_epoch_loss/=count
    146. D_loss.append(d_epoch_loss)
    147. G_loss.append(g_epoch_loss)
    148. gen_img_plot(gen, epoch, test_input)
    149. print("Epoch: ",epoch)
    150. print("-----finised-----")
    151. if __name__ == "__main__":
    152. train()

    参考:

    10.完整课程简介_哔哩哔哩_bilibili

    理论【PyTorch][chapter 19][李宏毅深度学习]【无监督学习][ GAN]【理论】-CSDN博客

  • 相关阅读:
    高仿拼多多源码/拼单商城系统源码/拼团商城源码
    Redisson入坑篇
    RK3568 安卓12 EC20模块NOCONN没有ip的问题(已解决)
    【微信小程序】button和image组件的基本使用
    基于 JSch 实现服务的自定义监控解决方案
    ubuntu循环登录,无法进入桌面
    【C++基础】2. 标准库
    C system()函数调用删除Windows临时目录下的所有文件
    微信小程序云开发教程——墨刀原型工具入门(页面交互+交互案例教程)
    重学前端——事件循环
  • 原文地址:https://blog.csdn.net/chengxf2/article/details/136393655