• 训练与推理


    1. import os
    2. import torch
    3. import torch.nn as nn
    4. import torch.optim as optim
    5. from torch.utils.tensorboard import SummaryWriter
    6. from datetime import datetime
    7. from dataset import get_train_loader_cifar100, get_val_loader_cifar100
    8. from utils import get_network, WarmUpLR, most_recent_folder, \
    9. most_recent_weights, last_epoch, best_acc_weights
    10. import pyzjr as pz
    11. from pyzjr.dlearn import GPU_INFO
    def train_one_epoch(trainingloader, epoch):
        """Train the module-level ``net`` for one pass over ``trainingloader``.

        For every batch this runs forward/backward/optimizer step, logs the
        gradient norms of the last child module and the batch loss to the
        module-level ``writer``, prints an in-place progress line, and steps
        ``warmup_scheduler`` while ``epoch`` is within ``args.warm``. After
        the pass, a histogram of every parameter is logged for ``epoch``.

        Args:
            trainingloader: Iterable of ``(images, labels)`` batches; must
                support ``len()`` and expose a ``dataset`` attribute.
            epoch: 1-based epoch number, used for the global iteration index
                and as the histogram step.
        """
        timer = pz.Timer()
        net.train()

        batches_per_epoch = len(trainingloader)
        dataset_size = len(trainingloader.dataset)
        progress_template = (
            'Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'
        )

        for step, (images, labels) in enumerate(trainingloader):
            if args.Cuda:
                labels = labels.cuda()
                images = images.cuda()

            optimizer.zero_grad()
            loss = loss_function(net(images), labels)
            loss.backward()
            optimizer.step()

            # Global iteration counter across epochs (starts at 1).
            global_step = (epoch - 1) * batches_per_epoch + step + 1

            # Track gradient magnitude on the network's final child module.
            final_module = list(net.children())[-1]
            for param_name, tensor in final_module.named_parameters():
                if 'weight' in param_name:
                    writer.add_scalar(
                        'LastLayerGradients/grad_norm2_weights', tensor.grad.norm(), global_step
                    )
                if 'bias' in param_name:
                    writer.add_scalar(
                        'LastLayerGradients/grad_norm2_bias', tensor.grad.norm(), global_step
                    )

            # '\r' keeps the progress report on a single, overwritten line.
            print(
                progress_template.format(
                    loss.item(),
                    optimizer.param_groups[0]['lr'],
                    epoch=epoch,
                    trained_samples=step * args.batch_size + len(images),
                    total_samples=dataset_size,
                ),
                end='\r',
                flush=True,
            )

            writer.add_scalar('Train/loss', loss.item(), global_step)

            # The warm-up schedule advances once per batch, not per epoch.
            if epoch <= args.warm:
                warmup_scheduler.step()

        for param_name, tensor in net.named_parameters():
            # splitext splits at the last dot: "a.b.weight" -> ("a.b", ".weight").
            layer, dotted_attr = os.path.splitext(param_name)
            writer.add_histogram("{}/{}".format(layer, dotted_attr[1:]), tensor, epoch)

        timer.stop()
        print('epoch {} training time consumed: {:.2f}s'.format(epoch, timer.total()))
    @torch.no_grad()
    def eval_training(testloader, epoch=0, tb=True):
        """Evaluate the module-level ``net`` on ``testloader``.

        Args:
            testloader: Iterable of ``(images, labels)`` batches exposing a
                ``dataset`` attribute whose length is the sample count.
            epoch: Epoch number shown in the report and used as the
                TensorBoard step.
            tb: When true, write loss and accuracy to the module-level
                ``writer``.

        Returns:
            Number of correct predictions divided by
            ``len(testloader.dataset)``.
        """
        time = pz.Timer()
        net.eval()

        test_loss = 0.0  # running sum of the per-batch loss values
        correct = 0.0

        for (images, labels) in testloader:
            if args.Cuda:
                images = images.cuda()
                labels = labels.cuda()

            outputs = net(images)
            loss = loss_function(outputs, labels)

            test_loss += loss.item()
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum()

        time.stop()
        if args.Cuda:
            GPU_INFO(
                headColor="red",
                gpuColor="blue"
            )

        # All statistics are derived from the loader that was actually passed
        # in, never from a module-level loader, so the printed, logged and
        # returned values always describe the same dataset.
        num_samples = len(testloader.dataset)
        # NOTE(review): test_loss adds one value per batch but is divided by
        # the sample count; confirm this is the intended normalisation.
        average_loss = test_loss / num_samples
        accuracy = correct.float() / num_samples

        print('Evaluating Network.....')
        print('Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'.format(
            epoch,
            average_loss,
            accuracy,
            time.total()
        ), end='\r', flush=True)
        print()

        # add information to tensorboard
        if tb:
            writer.add_scalar('Test/Average loss', average_loss, epoch)
            writer.add_scalar('Test/Accuracy', accuracy, epoch)

        return accuracy
    if __name__ == '__main__':
        class parser_args():
            """Hard-coded run configuration used in place of command-line arguments."""

            def __init__(self):
                self.net = "vgg16"
                self.Cuda = True
                self.EPOCHS = 100
                self.batch_size = 4
                self.warm = 1
                self.CHECKPOINT_PATH = 'checkpoint'
                self.resume = False
                self.lr = 0.01
                self.LOG_DIR = "logs"
                self.SAVE_EPOCH = 10
                self.MILESTONES = [60, 120, 160]
                self.DATE_FORMAT = '%A_%d_%B_%Y_%Hh_%Mm_%Ss'
                self.TIME_NOW = datetime.now().strftime(self.DATE_FORMAT)
                self.CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
                self.CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

            def _help(self):
                """Return a dict mapping option names to their descriptions."""
                stc = {
                    "log_dir": "存放训练模型.pth的路径",
                    "Cuda": "是否使用Cuda,如果没有GPU,可以使用CUP,i.e: Cuda=False",
                    "EPOCHS": "训练的轮次,这里默认就跑100轮",
                    "batch_size": "批量大小,一般为1,2,4",
                    "warm": "控制学习率的'热身'或'预热'过程"
                }
                return stc

        args = parser_args()
        net = get_network(args)

        # data preprocessing:
        training_loader = get_train_loader_cifar100(
            args.CIFAR100_TRAIN_MEAN,
            args.CIFAR100_TRAIN_STD,
            num_workers=4,
            batch_size=args.batch_size,
            shuffle=True
        )
        test_loader = get_val_loader_cifar100(
            args.CIFAR100_TRAIN_MEAN,
            args.CIFAR100_TRAIN_STD,
            num_workers=4,
            batch_size=args.batch_size,
            shuffle=True
        )

        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
        # learning rate decay
        train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.MILESTONES, gamma=0.2)
        iter_per_epoch = len(training_loader)
        warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

        if args.resume:
            recent_folder = most_recent_folder(os.path.join(args.CHECKPOINT_PATH, args.net), fmt=args.DATE_FORMAT)
            if not recent_folder:
                raise Exception('no recent folder were found')
            checkpoint_path = os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder)
        else:
            checkpoint_path = os.path.join(args.CHECKPOINT_PATH, args.net, args.TIME_NOW)

        # exist_ok avoids the check-then-create race and also handles a nested LOG_DIR.
        os.makedirs(args.LOG_DIR, exist_ok=True)
        writerlog_path = pz.logdir(dir=args.LOG_DIR, format=True, prefix=args.net)
        writer = SummaryWriter(writerlog_path)

        # Dummy input used only to trace the network graph for TensorBoard.
        input_tensor = torch.Tensor(1, 3, 32, 32)
        if args.Cuda:
            input_tensor = input_tensor.cuda()
        writer.add_graph(net, input_tensor)

        # create checkpoint folder to save model
        os.makedirs(checkpoint_path, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_path, '{net}-{epoch}-{type}.pth')

        best_acc = 0.0
        if args.resume:
            best_weights = best_acc_weights(os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder))
            if best_weights:
                weights_path = os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder, best_weights)
                print('found best acc weights file:{}'.format(weights_path))
                print('load best training file to test acc...')
                net.load_state_dict(torch.load(weights_path))
                # eval_training requires the loader as its first argument;
                # omitting it raised TypeError whenever best weights existed.
                best_acc = eval_training(test_loader, tb=False)
                print('best acc is {:0.2f}'.format(best_acc))

            recent_weights_file = most_recent_weights(os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder))
            if not recent_weights_file:
                raise Exception('no recent weights file were found')
            weights_path = os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder, recent_weights_file)
            print('loading weights file {} to resume training.....'.format(weights_path))
            net.load_state_dict(torch.load(weights_path))

            resume_epoch = last_epoch(os.path.join(args.CHECKPOINT_PATH, args.net, recent_folder))

        for epoch in range(1, args.EPOCHS + 1):
            # When resuming, epochs already covered by the loaded checkpoint
            # must be skipped *before* training; otherwise they are trained
            # again and resuming saves no work. The LR schedule is still
            # advanced so later epochs see the right learning rate.
            if args.resume and epoch <= resume_epoch:
                if epoch > args.warm:
                    train_scheduler.step(epoch)
                continue

            train_one_epoch(training_loader, epoch)
            acc = eval_training(test_loader, epoch)

            if epoch > args.warm:
                # NOTE(review): recent PyTorch releases deprecate passing the
                # epoch to step(); kept so the milestone alignment is unchanged.
                train_scheduler.step(epoch)

            # start to save best performance model after the second LR milestone
            # NOTE(review): MILESTONES[1] is 120 while EPOCHS is 100, so with
            # the default configuration this branch never runs - confirm.
            if epoch > args.MILESTONES[1] and best_acc < acc:
                weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='best')
                print('saving weights file to {}'.format(weights_path))
                torch.save(net.state_dict(), weights_path)
                best_acc = acc
                continue

            if not epoch % args.SAVE_EPOCH:
                weights_path = checkpoint_path.format(net=args.net, epoch=epoch, type='regular')
                print('saving weights file to {}'.format(weights_path))
                torch.save(net.state_dict(), weights_path)

        writer.close()

    工程代码:

    Auorui/Pytorch-Classification-Model-Based-on-CIFAR-100: 基于CIFAR-100的Pytorch分类模型 (github.com)

  • 相关阅读:
    牛客题目——链表的奇偶重排、输出二叉树的右视图、括号生成、字符流中第一个不重复的字符
    怎样判断气门油封有问题?
    一篇文章带你搞定Java封装
    【论文阅读】Hypergraph Convolutional Network for Group Recommendation
    R语言data.table包进行数据分组聚合统计变换(Aggregating transforms)、计算dataframe数据的分组标准差(sd)
    qt开发技巧与三个问题点
    Hadoop核心之MapReduce框架总结Ⅱ
    spark sql重分区
    java jiraClient 针对某个issue增加评论
    倾斜摄影测量实景三维建模ContextCapture Master
  • 原文地址:https://blog.csdn.net/m0_62919535/article/details/134367177