• 深度学习pytorch训练代码模板(个人习惯)


    深度学习pytorch训练代码模板(个人习惯)

    来源:https://zhuanlan.zhihu.com/p/396666255

    从参数定义,到网络模型定义,再到训练步骤,验证步骤,测试步骤,总结了一套较为直观的模板。目录如下:
    导入包以及设置随机种子
    以类的方式定义超参数
    定义自己的模型
    定义早停类(此步骤可以省略)
    定义自己的数据集Dataset,DataLoader
    实例化模型,设置loss,优化器等
    开始训练以及调整lr
    绘图
    预测
    一、导入包以及设置随机种子

    import numpy as np
    import torch
    import torch.nn as nn
    import numpy as np
    import pandas as pd
    from torch.utils.data import DataLoader, Dataset
    from sklearn.model_selection import train_test_split
    import matplotlib.pyplot as plt
    
    import random
    # One shared seed so that Python, NumPy and PyTorch draw reproducible
    # random numbers across runs.
    seed = 42
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch RNG
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14

    二、以类的方式定义超参数

    class argparse():
        """Empty attribute container used as a simple hyper-parameter namespace."""


    args = argparse()

    # Training schedule.
    args.epochs = 30
    args.learning_rate = 0.001
    args.patience = 4

    # Model dimensions.
    args.hidden_size = 40
    args.input_size = 30

    # Use the first CUDA device when one is available, otherwise the CPU.
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    三、定义自己的模型
    class Your_model(nn.Module):
        """Skeleton network: declares no layers and returns its input unchanged."""

        def __init__(self):
            super().__init__()
            # TODO: declare the layers of your network here.

        def forward(self, x):
            """Return ``x`` as-is; replace with the real forward computation."""
            # TODO: apply the layers declared in __init__ to x.
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    四、定义早停类(此步骤可以省略)

    class EarlyStopping():
        """Stop training once the validation loss has stopped improving.

        Call the instance once per epoch with the current validation loss.
        Whenever the loss improves, the model's ``state_dict`` is saved to
        ``<path>/model_checkpoint.pth``. After ``patience`` consecutive calls
        without improvement, ``early_stop`` is set to True.

        Args:
            patience: number of consecutive non-improving calls tolerated
                before ``early_stop`` becomes True.
            verbose: if True, print a message each time a checkpoint is saved.
            delta: margin added to the best score; a new score below
                ``best_score + delta`` counts as "no improvement".
        """

        def __init__(self, patience=7, verbose=False, delta=0):
            self.patience = patience
            self.verbose = verbose
            self.counter = 0          # consecutive calls without improvement
            self.best_score = None    # best (negated) validation loss so far
            self.early_stop = False
            # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
            self.val_loss_min = np.inf
            self.delta = delta

        def __call__(self, val_loss, model, path):
            """Record one validation loss and update the early-stopping state.

            Args:
                val_loss: validation loss of the epoch that just finished.
                model: object exposing ``state_dict()``; saved on improvement.
                path: directory in which the checkpoint file is written.
            """
            print("val_loss={}".format(val_loss))
            # Negate so that "higher score" means "better".
            score = -val_loss
            if self.best_score is None:
                # First call: always take it as the baseline and checkpoint.
                self.best_score = score
                self.save_checkpoint(val_loss, model, path)
            elif score < self.best_score + self.delta:
                self.counter += 1
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                if self.counter >= self.patience:
                    self.early_stop = True
            else:
                self.best_score = score
                self.save_checkpoint(val_loss, model, path)
                self.counter = 0

        def save_checkpoint(self, val_loss, model, path):
            """Save ``model.state_dict()`` under ``path`` and remember ``val_loss``."""
            if self.verbose:
                print(
                    f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            torch.save(model.state_dict(), path + '/' + 'model_checkpoint.pth')
            self.val_loss_min = val_loss
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    五、定义自己的数据集Dataset,DataLoader

    class Dataset_name(Dataset):
        """Template dataset; fill in the loading and indexing logic.

        Args:
            flag: which split this instance serves; one of
                'train', 'test' or 'valid'.
            csv_paths: optional list of source files, forwarded unchanged to
                ``__load_data__``. Defaults to None.
        """

        def __init__(self, flag='train', csv_paths=None):
            assert flag in ['train', 'test', 'valid']
            self.flag = flag
            # Forward csv_paths: __load_data__ previously required it but was
            # called with no argument, so construction always raised TypeError.
            self.__load_data__(csv_paths)

        def __getitem__(self, index):
            # TODO: return the (x, y) sample at `index` for the split in self.flag.
            pass

        def __len__(self):
            # TODO: return the number of samples in the split in self.flag.
            pass

        def __load_data__(self, csv_paths: list = None):
            """Load the data and populate the split attributes.

            The summary print below reads ``self.train_X``, ``self.train_Y``,
            ``self.valid_X`` and ``self.valid_Y``, so the loading code written
            in place of the TODO must assign all four.
            """
            # TODO: read `csv_paths` and assign self.train_X / train_Y / valid_X / valid_Y.
            pass
            print(
                "train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
                .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))
    
    # Training split: shuffled every epoch.
    train_dataset = Dataset_name(flag='train')
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    # Validation split: not shuffled, so the batches (and therefore the mean of
    # the per-batch losses used for early stopping) are the same every epoch.
    valid_dataset = Dataset_name(flag='valid')
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=False)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21

    六、实例化模型,设置loss,优化器等

    # Build the network and move it to the configured device.
    model = Your_model().to(args.device)
    criterion = torch.nn.MSELoss()
    # Optimise the parameters of the *instance*; calling parameters() on the
    # Your_model class itself raises TypeError (missing `self`).
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Per-batch losses, accumulated across all epochs.
    train_loss = []
    valid_loss = []
    # Per-epoch mean losses.
    train_epochs_loss = []
    valid_epochs_loss = []

    early_stopping = EarlyStopping(patience=args.patience, verbose=True)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    七、开始训练以及调整lr

    # Learning-rate schedule: after finishing epoch <key>, switch to lr <value>.
    # Built once here instead of on every epoch.
    lr_adjust = {
        2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
        10: 5e-7, 15: 1e-7, 20: 5e-8
    }
    # Log about twice per epoch. max(1, ...) guards against a modulo-by-zero
    # when the training loader yields a single batch.
    log_interval = max(1, len(train_dataloader) // 2)

    for epoch in range(args.epochs):
        # Always use the `model` instance; the Your_model class itself cannot
        # be switched to train/eval mode or used for a forward pass.
        model.train()
        train_epoch_loss = []
        for idx, (data_x, data_y) in enumerate(train_dataloader, 0):
            data_x = data_x.to(torch.float32).to(args.device)
            data_y = data_y.to(torch.float32).to(args.device)
            optimizer.zero_grad()
            outputs = model(data_x)
            # (prediction, target) order, matching the validation pass below.
            loss = criterion(outputs, data_y)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            train_epoch_loss.append(batch_loss)
            train_loss.append(batch_loss)
            if idx % log_interval == 0:
                print("epoch={}/{},{}/{}of train, loss={}".format(
                    epoch, args.epochs, idx, len(train_dataloader), batch_loss))
        train_epochs_loss.append(np.average(train_epoch_loss))

        #=====================valid============================
        model.eval()
        valid_epoch_loss = []
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for idx, (data_x, data_y) in enumerate(valid_dataloader, 0):
                data_x = data_x.to(torch.float32).to(args.device)
                data_y = data_y.to(torch.float32).to(args.device)
                outputs = model(data_x)
                loss = criterion(outputs, data_y)
                valid_epoch_loss.append(loss.item())
                valid_loss.append(loss.item())
        valid_epochs_loss.append(np.average(valid_epoch_loss))
        #==================early stopping======================
        # Pass the instance so that its state_dict() can be checkpointed.
        early_stopping(valid_epochs_loss[-1], model=model, path=r'c:\\your_model_to_save')
        if early_stopping.early_stop:
            print("Early stopping")
            break
        #====================adjust lr========================
        if epoch in lr_adjust:
            lr = lr_adjust[epoch]
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Updating learning rate to {}'.format(lr))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44

    八、绘图

    # Two side-by-side panels: per-batch training loss on the left,
    # per-epoch training/validation loss on the right.
    fig, (ax_batch, ax_epoch) = plt.subplots(1, 2, figsize=(12, 4))

    ax_batch.plot(train_loss)
    ax_batch.set_title("train_loss")

    # The first epoch is left out of the per-epoch curves.
    ax_epoch.plot(train_epochs_loss[1:], '-o', label="train_loss")
    ax_epoch.plot(valid_epochs_loss[1:], '-o', label="valid_loss")
    ax_epoch.set_title("epochs_loss")
    ax_epoch.legend()

    plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    九、预测

    此处可以为预测数据单独定义一个Dataset和DataLoader;也可以直接对预测数据做reshape,在最前面补上一个大小为1的batch维度(即batch_size=1)后再送入模型。

    # Run inference on the trained instance: Your_model.eval() on the class
    # raises TypeError, and Your_model(data) would call the constructor
    # rather than perform a forward pass.
    model.eval()
    # Gradients are not needed for prediction.
    with torch.no_grad():
        # NOTE(review): `data` is supplied by the user; presumably it must
        # already be a float32 tensor on args.device, as in training - confirm.
        predict = model(data)
    
    • 1
    • 2

    【项目推荐】

    面向小白的顶会论文核心代码库:https://github.com/xmu-xiaoma666/External-Attention-pytorch

    面向小白的YOLO目标检测库:https://github.com/iscyy/yoloair

    面向小白的顶刊顶会的论文解析:https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading

    “点个在看,月薪十万!”

    “学会点赞,身价千万!”

  • 相关阅读:
    阿里云安全恶意程序检测(速通二)
    【自用】深度学习工作站安装ubuntu 18.04 LTS系统
    行业调研:2022年养老保险市场现状及前景分析
    019 基于Spring Boot的教务管理系统、学生管理系统、课表查询系统
    高低JDK版本中JNDI注入(上)
    基于Java+vue前后端分离高校社团管理系统设计实现(源码+lw+部署文档+讲解等)
    [深度学习] Python人脸识别库Deepface使用教程
    数据结构:树的概念和结构
    class12:async 和 await
    linux 下设置IP,mac,netmask
  • 原文地址:https://blog.csdn.net/Jason_android98/article/details/126883547