• pytorch常用代码片段


    # 注:以下代码仅供参考借鉴,需结合自己数据和模型进行相应的修改调整
    
    • 1

    0、PyTorch介绍

    PyTorch是使用GPU和CPU优化的深度学习张量库,主要用来搭建深度学习框架。与Tensorflow的静态计算图不同,PyTorch的计算图是动态的,可以根据计算需要实时改变计算图。

    1、数据加载

    自己编写 dataload,主要是以下三部分
    1)init 中读取文件内容
    2)len 实现 len(dataset) 返还数据集的尺寸
    3)getitem 用来获取一些索引数据,例如 使用dataset[i] 获得第i个样本

    class MyDataLoad():
    	def __init__(self):
    		pass
    	def __getitem__(self,item):
    		pass
    	def __len__(self):
    		return self._len 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    2、模型搭建

    ''' model--mlp '''
    class MLP(nn.Module):
        def __init__(self):
            super(MLP,self).__init__()
            self.flat = nn.Flatten()
            self.fc1 = nn.Linear(48*10,24)
            
        def forward(self,din):
            din = din.to(torch.float32)
            din = self.flat(din)
            dout = self.fc1(din)
            return dout
    py_model = MLP()
    # 该库可以像 keras 的 summary 一样打印模型架构
    from torchkeras import summary
    summary(py_model, input_shape=(48,10)) 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    3、模型训练

    	# 加载数据
        train_data = MyDataLoad("train")
        train_loader = DataLoader(train_data, batch_size=64, shuffle=False, num_workers=0)
        model = MLP()
        
        # 开始训练
        criterion = nn.MSELoss()
        optimizer = optim.Adam(params=model.parameters(), lr = 0.01)
    
        # train model
        model.train()
        Epoch = 10
        for epoch in range(Epoch):
            epoch_loss = 0.0
            start_time = time.time()
            ct = 0
            for data in train_loader:
                ct += 1
                model.zero_grad()
                predict_value = model(data["flow_x"])
                predict_value = predict_value.to(torch.float32)
                data["flow_y"] = data["flow_y"].to(torch.float32)
                data["flow_y"] = data["flow_y"].reshape(-1,24)
                predict_value = predict_value.reshape(-1, 24)
    
                loss = criterion(predict_value, data["flow_y"])
                epoch_loss = epoch_loss + loss.item()
    
                loss.requires_grad_(True)
                loss.backward()
                optimizer.step()
            end_time = time.time()
            print("Epoch: {:04d}, Loss: {:02.4f}, Time: {:02.2f} mins".format(epoch, 1000 * epoch_loss / len(train_data),
                                                                              (end_time - start_time) / 60))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    4、模型评估

    	# test model
        test_data = XxDataLoad("test")
        test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=0)
        model.eval()
        with torch.no_grad():
            Target = np.zeros([2,1,1])
            pre_model = []
    
            total_loss = 0.0
            ct = 0
            for data in test_loader:
                predict_value = model(data["flow_x"]) 
                predict_value = predict_value.to(torch.float32)
    
                for i in range(len(predict_value)):
                    ct += 1
                    tmp = predict_value[i].numpy()
                    tmp = np.reshape(tmp, (1,24))
                    pre_model.append(tmp)
                    # print(len(predict_value), np.shape(pre_model[0]))
                    # exit()
                    if ct==7000:
                        print(len(predict_value), np.shape(pre_model))
                        break
    
                data["flow_y"] = data["flow_y"].to(torch.float32)
                data["flow_y"] = data["flow_y"].reshape(-1, 24)
                predict_value = predict_value.reshape(-1, 24)
    
                loss = criterion(predict_value, data["flow_y"])
                total_loss += loss.item()
    
            print("Test Loss: {:02.4f}".format(1000 * total_loss / len(test_data)))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33

    5、模型保存

    1)只保存模型参数字典

    torch.save(model.state_dict(), PATH)
    
    • 1

    2)保存整个模型

    torch.save(the_model, PATH)
    
    • 1

    6、模型加载

    1)只保存模型参数字典

    model = MyModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    
    • 1
    • 2

    2)保存整个模型

    model = torch.load(PATH)
    
    • 1

    7、Trick

    1)拼接多个tensor:torch.cat((A,B),axis)
    axis=0为按行拼接;axis=1为按列拼接
    2)python选取tensor某一维_Pytorch的Tensor操作(1)
    x = [[1,2,3],[4,5,6]] # (2,3)
    res = x[:, 0] # (2,1) res=[[1],[4]]
    3)torch.mul, mm, bmm, matmul, spmm 等矩阵乘法
    4)torch.eye(n) 创建对角矩阵;
    5)x.transpose(0,1) 将0维和1维元素进行转置
    6)numpy转tensor b=torch.tensor(a)
    7)转float32 x=x.to(torch.float32)

  • 相关阅读:
    Qt开发技术:Q3D图表开发笔记(三):Q3DSurface三维曲面图介绍、Demo以及代码详解
    面向对象实验四类的继承
    最长公共子串问题
    python opencv比较图片相似度
    uniapp tabBar app页面滚动闪屏的问题
    【Linux】感性认识冯诺依曼体系结构和操作系统
    vulnhub靶场之DOUBLETROUBLE: 1
    Ajax的简单使用
    Tcp三次握手和四次挥手
    Spring中如何在一个Bean中注入一个内部Bean呢?
  • 原文地址:https://blog.csdn.net/qq_44391957/article/details/127440814