• 训练与推理


    1. import os
    2. import torch
    3. import torch.nn as nn
    4. import torch.optim as optim
    5. from torch.utils.tensorboard import SummaryWriter
    6. from datetime import datetime
    7. from dataset import get_train_loader_cifar100, get_val_loader_cifar100
    8. from utils import get_network, WarmUpLR, most_recent_folder, \
    9. most_recent_weights, last_epoch, best_acc_weights
    10. import pyzjr as pz
    11. from pyzjr.dlearn import GPU_INFO
    def train_one_epoch(trainingloader,epoch):
        """Run one optimisation pass over ``trainingloader`` and log it to TensorBoard.

        Relies on the module-level ``net``, ``optimizer``, ``loss_function``,
        ``writer``, ``warmup_scheduler`` and ``args`` created in the
        ``__main__`` section.

        Args:
            trainingloader: iterable of ``(images, labels)`` batches; must
                support ``len()`` and expose a ``dataset`` attribute.
            epoch: 1-based epoch number, used for the global step counter,
                the progress line and the histogram step.
        """
        timer = pz.Timer()
        net.train()

        batches_per_epoch = len(trainingloader)
        dataset_size = len(trainingloader.dataset)
        # (substring looked for in the parameter name, TensorBoard tag)
        grad_tags = (
            ('weight', 'LastLayerGradients/grad_norm2_weights'),
            ('bias', 'LastLayerGradients/grad_norm2_bias'),
        )

        for step, (images, labels) in enumerate(trainingloader, start=1):
            if args.Cuda:
                labels = labels.cuda()
                images = images.cuda()

            optimizer.zero_grad()
            loss = loss_function(net(images), labels)
            loss.backward()
            optimizer.step()

            # Global iteration index across all epochs (starts at 1).
            global_step = (epoch - 1) * batches_per_epoch + step

            # Log gradient norms of the last child module of the network.
            final_module = list(net.children())[-1]
            for pname, tensor in final_module.named_parameters():
                for needle, tag in grad_tags:
                    if needle in pname:
                        writer.add_scalar(tag, tensor.grad.norm(), global_step)

            loss_value = loss.item()
            current_lr = optimizer.param_groups[0]['lr']
            seen = (step - 1) * args.batch_size + len(images)
            # '\r' keeps the progress report on a single, continuously updated line.
            print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format(
                loss_value,
                current_lr,
                epoch=epoch,
                trained_samples=seen,
                total_samples=dataset_size
            ), end='\r', flush=True)

            writer.add_scalar('Train/loss', loss_value, global_step)

            # The warm-up scheduler is stepped per batch during the warm-up epochs only.
            if epoch <= args.warm:
                warmup_scheduler.step()

        # One histogram per parameter; the name is split at its last dot into
        # "<layer>/<attribute>" (os.path.splitext is used as the splitter).
        for pname, tensor in net.named_parameters():
            layer, suffix = os.path.splitext(pname)
            writer.add_histogram("{}/{}".format(layer, suffix[1:]), tensor, epoch)

        timer.stop()
        print('epoch {} training time consumed: {:.2f}s'.format(epoch, timer.total()))
    @torch.no_grad()
    def eval_training(testloader,epoch=0, tb=True):
        """Evaluate the module-level ``net`` on ``testloader``.

        Relies on the module-level ``net``, ``loss_function``, ``writer`` and
        ``args`` created in the ``__main__`` section.

        Args:
            testloader: iterable of ``(images, labels)`` batches exposing a
                ``dataset`` attribute whose length is the number of samples.
            epoch: epoch number used in the printed report and as the
                TensorBoard step.
            tb: when true, the loss and accuracy are also written to
                TensorBoard.

        Returns:
            A scalar float tensor: correct predictions divided by
            ``len(testloader.dataset)``.
        """
        timer = pz.Timer()
        net.eval()

        test_loss = 0.0  # running sum of the per-batch loss values
        correct = 0.0
        for images, labels in testloader:
            if args.Cuda:
                images = images.cuda()
                labels = labels.cuda()

            outputs = net(images)
            loss = loss_function(outputs, labels)

            test_loss += loss.item()
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum()
        timer.stop()

        if args.Cuda:
            GPU_INFO(
                headColor="red",
                gpuColor="blue"
            )

        # All metrics are normalised by the loader that was actually evaluated
        # (previously the global ``test_loader`` was used for TensorBoard and
        # for the return value, regardless of the ``testloader`` argument).
        num_samples = len(testloader.dataset)
        # NOTE(review): this divides the sum of per-batch losses by the number
        # of samples, so with a batch-averaged criterion the value is roughly
        # the true mean loss divided by the batch size. Kept as in the original.
        average_loss = test_loss / num_samples
        # ``correct`` is still a plain float if the loader yielded no batches.
        accuracy = torch.as_tensor(correct).float() / num_samples

        print('Evaluating Network.....')
        print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format(
            epoch,
            average_loss,
            accuracy,
            timer.total()
        ), end='\r', flush=True)
        print()

        # add information to tensorboard
        if tb:
            writer.add_scalar('Test/Average loss', average_loss, epoch)
            writer.add_scalar('Test/Accuracy', accuracy, epoch)

        return accuracy
    if __name__ == '__main__':
        class parser_args:
            """Hard-coded run configuration (stands in for a CLI argument parser)."""

            def __init__(self):
                self.net = "vgg16"
                self.Cuda = True
                self.EPOCHS = 100
                self.batch_size = 4
                self.warm = 1
                self.CHECKPOINT_PATH = 'checkpoint'
                self.resume = False
                self.lr = 0.01
                self.LOG_DIR = "logs"
                self.SAVE_EPOCH = 10
                self.MILESTONES = [60, 120, 160]
                self.DATE_FORMAT = '%A_%d_%B_%Y_%Hh_%Mm_%Ss'
                self.TIME_NOW = datetime.now().strftime(self.DATE_FORMAT)
                self.CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
                self.CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

            def _help(self):
                """Return a dict mapping some option names to their descriptions."""
                stc = {
                    "log_dir": "存放训练模型.pth的路径",
                    "Cuda": "是否使用Cuda,如果没有GPU,可以使用CUP,i.e: Cuda=False",
                    "EPOCHS": "训练的轮次,这里默认就跑100轮",
                    "batch_size": "批量大小,一般为1,2,4",
                    "warm": "控制学习率的'热身'或'预热'过程"
                }
                return stc

        args = parser_args()
        net = get_network(args)

        # data preprocessing:
        training_loader = get_train_loader_cifar100(
            args.CIFAR100_TRAIN_MEAN,
            args.CIFAR100_TRAIN_STD,
            num_workers=4,
            batch_size=args.batch_size,
            shuffle=True
        )
        test_loader = get_val_loader_cifar100(
            args.CIFAR100_TRAIN_MEAN,
            args.CIFAR100_TRAIN_STD,
            num_workers=4,
            batch_size=args.batch_size,
            shuffle=True
        )

        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.MILESTONES, gamma=0.2)  # learning rate decay
        iter_per_epoch = len(training_loader)
        warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

        # Directory holding this run's checkpoints: the most recent existing
        # run when resuming, otherwise a new folder named after the start time.
        if args.resume:
            recent_folder = most_recent_folder(os.path.join(args.CHECKPOINT_PATH, args.net), fmt=args.DATE_FORMAT)
            if not recent_folder:
                raise FileNotFoundError('no recent folder were found')
            checkpoint_path = os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder)
        else:
            checkpoint_path = os.path.join(args.CHECKPOINT_PATH, args.net, args.TIME_NOW)

        os.makedirs(args.LOG_DIR, exist_ok=True)
        writerlog_path = pz.logdir(dir=args.LOG_DIR, format=True, prefix=args.net)
        writer = SummaryWriter(writerlog_path)

        # Dummy input used only to trace the graph; zeros instead of
        # uninitialised memory so tracing never sees NaN/Inf garbage.
        input_tensor = torch.zeros(1, 3, 32, 32)
        if args.Cuda:
            input_tensor = input_tensor.cuda()
        writer.add_graph(net, input_tensor)

        # create checkpoint folder to save model
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

        # Load checkpoints onto the device selected by the configuration so
        # GPU-saved weights can also be restored when Cuda is disabled.
        load_device = torch.device('cuda' if args.Cuda else 'cpu')

        best_acc = 0.0
        resume_epoch = 0
        if args.resume:
            resume_dir = os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder)

            best_weights = best_acc_weights(resume_dir)
            if best_weights:
                weights_path = os.path.join(resume_dir, best_weights)
                print('found best acc weights file:{}'.format(weights_path))
                print('load best training file to test acc...')
                net.load_state_dict(torch.load(weights_path, map_location=load_device))
                # The loader must be passed explicitly; eval_training has no
                # default for it.
                best_acc = eval_training(test_loader, tb=False)
                print('best acc is {:0.2f}'.format(best_acc))

            recent_weights_file = most_recent_weights(resume_dir)
            if not recent_weights_file:
                raise FileNotFoundError('no recent weights file were found')
            weights_path = os.path.join(resume_dir, recent_weights_file)
            print('loading weights file {} to resume training.....'.format(weights_path))
            net.load_state_dict(torch.load(weights_path, map_location=load_device))

            resume_epoch = last_epoch(resume_dir)

        for epoch in range(1, args.EPOCHS + 1):
            # Epochs already covered by the resumed checkpoint are not trained
            # again; the LR scheduler is still advanced for them so training
            # continues with the learning rate of the resumed epoch.
            already_done = args.resume and epoch <= resume_epoch

            if not already_done:
                train_one_epoch(training_loader, epoch)
                acc = eval_training(test_loader, epoch)

            if epoch > args.warm:
                train_scheduler.step(epoch)

            if already_done:
                continue

            # start to save best performance model after learning rate decay to 0.01
            # NOTE(review): with the defaults above (EPOCHS=100, MILESTONES[1]=120)
            # this condition is never true, so no 'best' checkpoint is written.
            # Confirm whether EPOCHS or MILESTONES should be adjusted.
            if epoch > args.MILESTONES[1] and best_acc < acc:
                weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='best')
                print('saving weights file to {}'.format(weights_path))
                torch.save(net.state_dict(), weights_path)
                best_acc = acc
                continue

            if not epoch % args.SAVE_EPOCH:
                weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular')
                print('saving weights file to {}'.format(weights_path))
                torch.save(net.state_dict(), weights_path)

        writer.close()

    工程代码:

    Auorui/Pytorch-Classification-Model-Based-on-CIFAR-100: 基于CIFAR-100的Pytorch分类模型 (github.com)

  • 相关阅读:
    漏洞危害之一
    mysql隔离级别RR下的行锁、临键锁、间隙锁详解及运用
    【毕业设计】 python flas疫情爬虫可视化
    php获取文件扩展名的三种方法
    日志异常检测准确率低?一文掌握日志指标序列分类
    低代码&无代码,你知道该怎么区分和选择吗?
    sparksql明明插入了但是表里数据是null
    Idea运行支付宝网站支付demo踩坑解决及其测试注意事项
    window和linux的nacos安装
    虚拟dom比真实dom还快吗?90%回答掉坑里了
  • 原文地址:https://blog.csdn.net/m0_62919535/article/details/134367177