• 深度学习笔记_1、定义神经网络


     1、使用了PyTorch的nn.Module类来定义神经网络模型;使用nn.Linear来创建全连接层。(CPU)

    1. import torch.nn as nn
    2. import torch.nn.functional as F
    3. from torchsummary import summary
    4. # 定义神经网络模型
    5. class Net(nn.Module):
    6. def __init__(self):
    7. super(Net, self).__init__()
    8. self.fc1 = nn.Linear(in_features=250, out_features=100, bias=True) # 输入层到隐藏层1,具有250个输入特征和100个神经元
    9. self.fc2 = nn.Linear(100, 50) # 隐藏层2,具有10050个神经元
    10. self.fc3 = nn.Linear(50, 25) # 隐藏层3,具有5025个神经元
    11. self.fc4 = nn.Linear(25, 10) # 隐藏层4,具有2510个神经元
    12. self.fc5 = nn.Linear(10, 2) # 输出层,具有102个神经元,用于二分类任务
    13. # 前向传播函数
    14. def forward(self, x):
    15. x = x.view(-1, 250) # 将输入数据展平成一维张量
    16. x = F.relu(self.fc1(x)) # 使用ReLU激活函数传递到隐藏层1
    17. x = F.relu(self.fc2(x)) # 使用ReLU激活函数传递到隐藏层2
    18. x = F.relu(self.fc3(x)) # 使用ReLU激活函数传递到隐藏层3
    19. x = F.relu(self.fc4(x)) # 使用ReLU激活函数传递到隐藏层4
    20. x = self.fc5(x) # 输出层,没有显式激活函数
    21. return x
    22. if __name__ == '__main__':
    23. print(Net())
    24. model = Net()
    25. summary(model, (250,)) # 打印模型摘要信息,输入大小为(250,)

     

    2、GPU版本

    1. import torch
    2. import torch.nn as nn
    3. import torch.nn.functional as F
    4. from torchsummary import summary
    5. class Net(nn.Module):
    6. def __init__(self):
    7. super(Net, self).__init__()
    8. self.fc1 = nn.Linear(784, 100).to(device='cuda:0')
    9. self.fc2 = nn.Linear(100, 50).to(device='cuda:0')
    10. self.fc3 = nn.Linear(50, 25).to(device='cuda:0')
    11. self.fc4 = nn.Linear(25, 10).to(device='cuda:0')
    12. def forward(self, x):
    13. x = F.relu(self.fc1(x))
    14. x = F.relu(self.fc2(x))
    15. x = F.relu(self.fc3(x))
    16. x = F.relu(self.fc4(x))
    17. return x
    18. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    19. model = Net().to(device)
    20. input_data = torch.randn(784, 100).to(device)
    21. summary(model, (784, ))

  • 相关阅读:
    如何正确操作封箱机
    Python邮件发送程序代码
    nuxt 客户端路由跳转时同一组件重新挂载导致mounted 生命周期重复执行问题解决
    String的使用
    JavaScript apply、call、bind 函数详解
    快速入门Git
    【力扣】128. 最长连续序列 <哈希、Set>
    vue 模态框场景 阻止事件冒泡
    设计模式之(8)——代理模式
    基于chatgpt-on-wechat搭建个人知识库微信群聊机器人
  • 原文地址:https://blog.csdn.net/cfy2401926342/article/details/133436799