• tqdm高级使用方法(类keras进度条)


    简介

    在很多场景,我们希望对一个进度条标识其运行的内容(set_description),同时也希望在进度条中增加一些信息,如模型训练的精度等。本文就将基于tqdm,在实际应用中充实进度条。

    一、简单示例

    from tqdm import tqdm
    
    tq_bar = tqdm(range(10))
    for idx, i in enumerate(tq_bar):
        acc_ = i*10
        loss = 1/((i+1)*10)
        tq_bar.set_description(f'SimpleLoop [{idx+1}]')
        tq_bar.set_postfix(dict(acc=f'{acc_}%', loss=f'{loss:.3f}'))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 结果
    SimpleLoop [10]: 100%|██████████████████████████████████| 10/10 [00:01<00:00,  8.30it/s, acc=90%, loss=0.010]
    
    • 1

    二、在深度学习训练中使用(pytorch 类似 keras)

    
    import torch
    from torch import nn
    from torch.nn import functional as F
    from torch.optim import AdamW, Adam
    from torch.utils.data import Dataset, TensorDataset, DataLoader
    import torchvision as tv
    from torchvision import transforms
    
    class simpleCNN(nn.Module):
        def __init__(self, input_dim=3, n_class=10):
            super(simpleCNN, self).__init__()
            self.features = nn.Sequential(
                nn.Conv2d(input_dim, 32, kernel_size=7, padding=2,dilation=2, bias=False),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 1)
            )
            self.clf = nn.Sequential(
                nn.Linear(4608, 128),
                nn.ReLU(inplace=True),
                nn.Dropout(0.2),
                nn.Linear(128, 64),
                nn.ReLU(inplace=True),
                nn.Dropout(0.2),
                nn.Linear(64, n_class)
            )
    
        def forward(self, x):
            out = self.features(x)
            out = out.view(out.size(0), -1)
            return self.clf(out)
    
    
    transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                                    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                                    ])
    dt = tv.datasets.CIFAR10(train=True, download=True, root=r'D:\work\my_project\play_data', transform=transform)
    # dt = tv.datasets.CIFAR10(train=True, download=False, root=r'D:\work\my_project\play_data', transform=transform)
    dt_loader = DataLoader(dt, batch_size=256)
    
    model = simpleCNN(3, 10)
    loss_func = nn.CrossEntropyLoss()
    optm = AdamW(model.parameters(), lr=1e-3)
    
    for ep in range(5):
        one_batch_bar = tqdm(dt_loader)
        one_batch_bar.set_description(f'[ epoch: {ep+1} ]')
        step_counts = 0
        step_loss_sum = 0
        step_right = 0
        step_samples = 0
        for tmp_x, tmp_y in one_batch_bar:
            # forward
            optm.zero_grad()
            step_pred = model(tmp_x)
            step_loss = loss_func(step_pred, tmp_y)
            loss_print = step_loss.detach().numpy()
            step_right_i = (torch.argmax(step_pred, dim=1) == tmp_y).detach().numpy().sum()
            
            # backword
            step_loss.backward()
            optm.step()
            
            # info
            step_counts += 1
            step_loss_sum += loss_print
            step_right += step_right_i
            step_samples += len(tmp_y)
            one_batch_bar.set_postfix(dict(
                loss=f'{step_loss_sum/step_counts:.5f}',
                acc=f'{step_right/step_samples*100:.2f}%'
            ))
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 结果
    [ epoch: 1 ]: 100%|██████████████████████████████████████████████████████| 196/196 [00:39<00:00,  4.95it/s, loss=1.73634, acc=36.93%] 
    [ epoch: 2 ]: 100%|██████████████████████████████████████████████████████| 196/196 [00:38<00:00,  5.13it/s, loss=1.43507, acc=48.46%] 
    [ epoch: 3 ]: 100%|██████████████████████████████████████████████████████| 196/196 [00:45<00:00,  4.34it/s, loss=1.30025, acc=53.85%] 
    [ epoch: 4 ]: 100%|██████████████████████████████████████████████████████| 196/196 [00:37<00:00,  5.28it/s, loss=1.22050, acc=57.06%] 
    [ epoch: 5 ]: 100%|██████████████████████████████████████████████████████| 196/196 [00:42<00:00,  4.65it/s, loss=1.16387, acc=58.78%]
    
    • 1
    • 2
    • 3
    • 4
    • 5
  • 相关阅读:
    初识React -- 一篇文章让你会用react写东西
    Docker安装PostgreSQL
    【Azure 云服务】Azure Cloud Service (Extended Support) 云服务开启诊断日志插件 WAD Extension (Windows Azure Diagnostic) 无法正常工作的原因
    蓝桥杯控制PCF8591
    C语言:详细介绍了六种进程间通讯方式(还有一种socket在主页关于socket的介绍里面有详细介绍,欢迎观看)
    C# 关于托管调试助手 “FatalExecutionEngineError“:“运行时遇到了错误。解决方案
    诊断寻址方式
    为什么要学ib物理课程?
    屏幕分辨率:PC / 手机 屏幕常见分辨率,前端如何适配分辨率
    Flutter didUpdateWidget 的使用问题
  • 原文地址:https://blog.csdn.net/Scc_hy/article/details/126256530