• torch 神经网络模型构建


    点赞收藏关注!
    如需转载,请注明出处!

    torch 的模型搭建,做一下简要的介绍

    神经网络由对数据进行操作的层/模块(layers/modules)组成。
    torch.nn提供构建网络的所有blocks,在PyTorch中的每个modules都继承了nn.Module,可以构建各种复杂的网络结构。通过nn.Module定义神经网络,使用init初始化,对数据的所有操作都在forward()中实现

    class NeuralNetwork(nn.Module):
    def __init__(self):
    	super(NeuralNetwork, self).__init__()
    	self.flatten = nn.Flatten()
    	self.linear_relu_stack = nn.Sequential(
    	nn.Linear(28*28, 512),
    	nn.ReLU(),
    	nn.Linear(512, 512),
    	nn.ReLU(),
    	nn.Linear(512, 10),
    	nn.ReLU()
    	)
    
    #前向传播
    def forward(self, x)
    	x = self.flatten(x)
    	logits = self.linear_relu_stack(x)
    	return logits
    
    
    ##使用示例
    
    #检测是否有GPU可用,若有可以在GPU上训练模型
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Using {} device'.format(device))
    model = NeuralNetwork().to(device)
    print(model)
    X = torch.rand(1, 28, 28, device=device)
    logits = model(X)
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.argmax(1)
    print(f"Predicted class: {y_pred}")
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    为了方便理解,整理了代码中一些函数的意义

    • nn.Flatten()将连续的维度范围展平为张量。一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。
    • nn.Sequential()是PyTorch中的一个类,它允许用户将多个计算层按照顺序组合成一个模型。在深度学习中,模型可以是由各种不同类型的层组成的,例如卷积层、池化层、全连接层等。nn.Sequential()方法可以将这些层组合在一起,形成一个整体模型。
    • nn.Linear定义一个神经网络的线性层.,Linear其实就是对输入 X执行了一个线性变换的操作。
    • nn.ReLU()模型的激活函数,nn.relu函数是神经网络中常用的激活函数之一,即修正线性单元(Rectified Linear Unit)。ReLU函数的数学表示为f(x) = max(0, x),即输出值等于输入值和0中的较大者。ReLU函数的特点是在输入值大于0时,输出为输入值本身;而在输入值小于等于0时,输出为0。这意味着ReLU会将负值归零,而对正值不做修改

    如有帮助点赞收藏关注!
    如需转载,请注明出处!

  • 相关阅读:
    【LeetCode】接雨水 II [H](堆)
    工业物联网系统下如何实现设备数据采集与设备维护
    Spring基础(四):XML方式实现DI
    Java2 - 数据结构
    如何监控香港服务器的性能
    【电动车优化调度】基于模型预测控制(MPC)的凸优化算法的电动车优化调度(Matlab代码实现)
    亚马逊主图视频的那些事儿
    跨站攻击CSRF实验
    Java的浅拷贝与深拷贝
    Swoole Compiler 加密PHP源代码(简版)
  • 原文地址:https://blog.csdn.net/weixin_42362399/article/details/134530131