• Building a Convolutional Neural Network (CNN) in PyTorch for Video Action Recognition (with Source Code and Dataset)


    If you need the dataset and source code, please like, follow, and bookmark, then leave your QQ email in the comments~~~

    1. Action Recognition Overview

    Action recognition is a fundamental task in video understanding: it extracts semantic information from video and can supply general-purpose video representations to downstream tasks such as action detection and action localization.

    Existing video action datasets can be roughly divided into two types:

    1. Scene-related datasets: the scene supplies most of the semantic information, so the action can often be recognized well from a single frame.

    2. Temporally-dependent datasets: these place heavy demands on temporal relationships, and enough frames are required to recognize the action in the video accurately.

    For example, horse riding is highly scene-related: the horse and the grass already provide sufficient semantic information.

    Opening a cabinet, by contrast, is highly time-dependent: if the frame order is reversed, it is easily mistaken for closing the cabinet.

    [Figure: a scene-related action (horse riding) versus a temporally-dependent action (opening a cabinet)]

    2. Data Preparation

    Data preparation consists of extracting frames from the raw videos; the underlying principle is not elaborated here, but a minimal extraction sketch follows the link below.

    You can download the dataset yourself from the official website:

    Video action recognition dataset
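    Since the loaders in this project read pre-extracted frames rather than raw videos, each video must first be decoded into numbered images. Below is a minimal frame-extraction sketch (not part of the project source), assuming OpenCV is installed and that the img_00001.jpg-style naming matches the image_tmpl prefix your dataset config expects:

    import os
    import cv2

    def extract_frames(video_path, out_dir):
        """Decode one video and dump every frame as a numbered JPEG."""
        os.makedirs(out_dir, exist_ok=True)
        cap = cv2.VideoCapture(video_path)
        idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:  # end of stream
                break
            idx += 1
            # assumed naming convention: img_00001.jpg, img_00002.jpg, ...
            cv2.imwrite(os.path.join(out_dir, 'img_{:05d}.jpg'.format(idx)), frame)
        cap.release()
        return idx

    # Example (hypothetical paths):
    # extract_frames('videos/v_Riding_g01_c01.avi', 'frames/v_Riding_g01_c01')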

    3. Model Construction and Training

    Before covering model construction and training, you first need to understand the command-line arguments, in particular the two unnamed required (positional) arguments dataset and modality. The former selects the dataset; the latter specifies the input type: RGB images or Flow (optical flow) images.

    The full procedure is fairly tedious and is not walked through step by step here; an example launch command is shown below.
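    For instance, a training run might be launched as follows. This is a sketch, not a verified command: the dataset name ucf101 is an assumption (valid names depend on ops/dataset_config.py), and the hyperparameters are only illustrative; all flags shown are defined in opts.py below.

    python main.py ucf101 RGB --arch resnet50 --num_segments 8 \
        --lr 0.001 --epochs 120 --batch-size 16 -j 8 \
        --shift --shift_div 8 --shift_place blockres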

    The results are shown below. Training ultimately yields heat maps like the following: the network's attention decreases from red through yellow and green to blue, and the module localizes the spatio-temporal regions where the motion occurs quite well.

    [Figure: attention heat maps overlaid on video frames]

    4. Code

    The project structure is as follows: main.py, opts.py, and test_models.py at the top level, together with an ops/ package containing dataset.py, models.py, transforms.py, dataset_config.py, utils.py, and temporal_shift.py (the modules imported by the code below).

    main.py code:

    import os
    import time
    import shutil
    import torch.nn.parallel
    import torch.backends.cudnn as cudnn
    import torch.optim
    import torchvision
    import numpy as np
    from torch.nn.utils import clip_grad_norm_

    from ops.dataset import TSNDataSet
    from ops.models import TSN
    from ops.transforms import *
    from opts import parser
    from ops import dataset_config
    from ops.utils import AverageMeter, accuracy
    from ops.temporal_shift import make_temporal_pool

    from tensorboardX import SummaryWriter

    best_prec1 = 0


    def main():
        global args, best_prec1
        args = parser.parse_args()

        num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                          args.modality)
        full_arch_name = args.arch
        if args.shift:
            full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
        if args.temporal_pool:
            full_arch_name += '_tpool'
        args.store_name = '_'.join(
            ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
             'e{}'.format(args.epochs)])
        if args.non_local:
            args.store_name += '_nl'
        if args.suffix is not None:
            args.store_name += '_{}'.format(args.suffix)
        print('storing name: ' + args.store_name)

        check_rootfolders()

        model = TSN(num_class, args.num_segments, args.modality,
                    base_model=args.arch,
                    consensus_type=args.consensus_type,
                    dropout=args.dropout,
                    img_feature_dim=args.img_feature_dim,
                    partial_bn=not args.no_partialbn,
                    pretrain=args.pretrain,
                    is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                    fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                    temporal_pool=args.temporal_pool,
                    non_local=args.non_local)

        crop_size = model.crop_size
        scale_size = model.scale_size
        input_mean = model.input_mean
        input_std = model.input_std
        policies = model.get_optim_policies()
        train_augmentation = model.get_augmentation(
            flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

        model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

        optimizer = torch.optim.SGD(policies,
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        if args.resume:
            if args.temporal_pool:  # early temporal pool so that we can load the state_dict
                make_temporal_pool(model.module.base_model, args.num_segments)
            if os.path.isfile(args.resume):
                print(("=> loading checkpoint '{}'".format(args.resume)))
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print(("=> loaded checkpoint '{}' (epoch {})"
                       .format(args.resume, checkpoint['epoch'])))
            else:
                print(("=> no checkpoint found at '{}'".format(args.resume)))

        if args.tune_from:
            print(("=> fine-tuning from '{}'".format(args.tune_from)))
            sd = torch.load(args.tune_from)
            sd = sd['state_dict']
            model_dict = model.state_dict()
            replace_dict = []
            for k, v in sd.items():
                if k not in model_dict and k.replace('.net', '') in model_dict:
                    print('=> Load after remove .net: ', k)
                    replace_dict.append((k, k.replace('.net', '')))
            for k, v in model_dict.items():
                if k not in sd and k.replace('.net', '') in sd:
                    print('=> Load after adding .net: ', k)
                    replace_dict.append((k.replace('.net', ''), k))
            for k, k_new in replace_dict:
                sd[k_new] = sd.pop(k)
            keys1 = set(list(sd.keys()))
            keys2 = set(list(model_dict.keys()))
            set_diff = (keys1 - keys2) | (keys2 - keys1)
            print('#### Notice: keys that failed to load: {}'.format(set_diff))
            if args.dataset not in args.tune_from:  # new dataset
                print('=> New dataset, do not load fc weights')
                sd = {k: v for k, v in sd.items() if 'fc' not in k}
            model_dict.update(sd)
            model.load_state_dict(model_dict)

        if args.temporal_pool and not args.resume:
            make_temporal_pool(model.module.base_model, args.num_segments)

        cudnn.benchmark = True

        # Data loading code
        if args.modality != 'RGBDiff':
            normalize = GroupNormalize(input_mean, input_std)
        else:
            normalize = IdentityTransform()

        if args.modality == 'RGB':
            data_length = 1
        elif args.modality in ['Flow', 'RGBDiff']:
            data_length = 5

        train_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       transform=torchvision.transforms.Compose([
                           train_augmentation,
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                           normalize,
                       ]), dense_sample=args.dense_sample),
            batch_size=args.batch_size, shuffle=True,
            num_workers=args.workers, pin_memory=True,
            drop_last=True)  # prevent something not % n_GPU

        val_loader = torch.utils.data.DataLoader(
            TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                       new_length=data_length,
                       modality=args.modality,
                       image_tmpl=prefix,
                       random_shift=False,
                       transform=torchvision.transforms.Compose([
                           GroupScale(int(scale_size)),
                           GroupCenterCrop(crop_size),
                           Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                           normalize,
                       ]), dense_sample=args.dense_sample),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        # define loss function (criterion) and optimizer
        if args.loss_type == 'nll':
            criterion = torch.nn.CrossEntropyLoss().cuda()
        else:
            raise ValueError("Unknown loss type")

        for group in policies:
            print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
                group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

        if args.evaluate:
            validate(val_loader, model, criterion, 0)
            return

        log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
        with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
            f.write(str(args))
        tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

            # evaluate on validation set
            if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

                # remember best prec@1 and save checkpoint
                is_best = prec1 > best_prec1
                best_prec1 = max(prec1, best_prec1)
                tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

                output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
                print(output_best)
                log_training.write(output_best + '\n')
                log_training.flush()

                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                }, is_best)


    def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        if args.no_partialbn:
            model.module.partialBN(False)
        else:
            model.module.partialBN(True)

        # switch to train mode
        model.train()

        end = time.time()
        for i, (input, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            target = target.cuda()
            input_var = torch.autograd.Variable(input)
            target_var = torch.autograd.Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # compute gradient and do SGD step
            loss.backward()

            if args.clip_gradient is not None:
                total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)

            optimizer.step()
            optimizer.zero_grad()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                output = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                    epoch, i, len(train_loader), batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
                print(output)
                log.write(output + '\n')
                log.flush()

        tf_writer.add_scalar('loss/train', losses.avg, epoch)
        tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
        tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
        tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)


    def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # switch to evaluate mode
        model.eval()

        end = time.time()
        with torch.no_grad():
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()

                # compute output
                output = model(input)
                loss = criterion(output, target)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    output = ('Test: [{0}/{1}]\t'
                              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                              'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        i, len(val_loader), batch_time=batch_time, loss=losses,
                        top1=top1, top5=top5))
                    print(output)
                    if log is not None:
                        log.write(output + '\n')
                        log.flush()

        output = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
                  .format(top1=top1, top5=top5, loss=losses))
        print(output)
        if log is not None:
            log.write(output + '\n')
            log.flush()

        if tf_writer is not None:
            tf_writer.add_scalar('loss/test', losses.avg, epoch)
            tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
            tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)

        return top1.avg


    def save_checkpoint(state, is_best):
        filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))


    def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        if lr_type == 'step':
            decay = 0.1 ** (sum(epoch >= np.array(lr_steps)))
            lr = args.lr * decay
            decay = args.weight_decay
        elif lr_type == 'cos':
            import math
            lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
            decay = args.weight_decay
        else:
            raise NotImplementedError
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr * param_group['lr_mult']
            param_group['weight_decay'] = decay * param_group['decay_mult']


    def check_rootfolders():
        """Create log and model folder"""
        folders_util = [args.root_log, args.root_model,
                        os.path.join(args.root_log, args.store_name),
                        os.path.join(args.root_model, args.store_name)]
        for folder in folders_util:
            if not os.path.exists(folder):
                print('creating folder ' + folder)
                os.mkdir(folder)


    if __name__ == '__main__':
        main()
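    After training, main() leaves per-run artifacts behind: log/<store_name>/log.csv and args.txt, TensorBoard event files written by SummaryWriter in the same folder, and checkpoint/<store_name>/ckpt.pth.tar plus ckpt.best.pth.tar (see save_checkpoint). A hedged usage sketch (the dataset name is an assumption and the store-name path is abbreviated with ...):

    tensorboard --logdir log
    python main.py ucf101 RGB --shift --resume checkpoint/TSM_ucf101_RGB_.../ckpt.pth.tar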

    The opts.py module (command-line options) is as follows:

    # Fill in / adjust the arguments below for your own setup
    import argparse
    parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
    parser.add_argument('dataset', default="")
    parser.add_argument('modality', default="RGB", choices=['RGB', 'Flow'])
    parser.add_argument('--train_list', type=str, default="")
    parser.add_argument('--val_list', type=str, default="")
    parser.add_argument('--root_path', type=str, default="")
    parser.add_argument('--store_name', type=str, default="")
    # ========================= Model Configs ==========================
    parser.add_argument('--arch', type=str, default="BNInception")
    parser.add_argument('--num_segments', type=int, default=3)
    parser.add_argument('--consensus_type', type=str, default='avg')
    parser.add_argument('--k', type=int, default=3)
    parser.add_argument('--dropout', '--do', default=0.5, type=float,
                        metavar='DO', help='dropout ratio (default: 0.5)')
    parser.add_argument('--loss_type', type=str, default="nll",
                        choices=['nll'])
    parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
    parser.add_argument('--suffix', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default='imagenet')
    parser.add_argument('--tune_from', type=str, default=None, help='fine-tune from checkpoint')
    # ========================= Learning Configs ==========================
    parser.add_argument('--epochs', default=120, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=128, type=int,
                        metavar='N', help='mini-batch size (default: 128)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate')
    parser.add_argument('--lr_type', default='step', type=str,
                        metavar='LRtype', help='learning rate schedule type')
    parser.add_argument('--lr_steps', default=[50, 100], type=float, nargs="+",
                        metavar='LRSteps', help='epochs to decay learning rate by 10')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                        metavar='W', help='gradient norm clipping (default: disabled)')
    parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")
    # ========================= Monitor Configs ==========================
    parser.add_argument('--print-freq', '-p', default=20, type=int,
                        metavar='N', help='print frequency (default: 20)')
    parser.add_argument('--eval-freq', '-ef', default=5, type=int,
                        metavar='N', help='evaluation frequency (default: 5)')
    # ========================= Runtime Configs ==========================
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--snapshot_pref', type=str, default="")
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--flow_prefix', default="", type=str)
    parser.add_argument('--root_log', type=str, default='log')
    parser.add_argument('--root_model', type=str, default='checkpoint')
    parser.add_argument('--shift', default=False, action="store_true", help='use shift for models')
    parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
    parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')
    parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
    parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')

    The test_models.py script is as follows:

    # Notice that this file has been modified to support ensemble testing
    import argparse
    import time

    import torch.nn.parallel
    import torch.optim
    import torchvision
    import numpy as np
    from sklearn.metrics import confusion_matrix

    from ops.dataset import TSNDataSet
    from ops.models import TSN
    from ops.transforms import *
    from ops import dataset_config
    from torch.nn import functional as F

    # options
    parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
    parser.add_argument('dataset', type=str)

    # may contain splits
    parser.add_argument('--weights', type=str, default=None)
    parser.add_argument('--test_segments', type=str, default=25)
    parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
    parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
    parser.add_argument('--full_res', default=False, action="store_true",
                        help='use full resolution 256x256 for test')

    parser.add_argument('--test_crops', type=int, default=1)
    parser.add_argument('--coeff', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')

    # for true test
    parser.add_argument('--test_list', type=str, default=None)
    parser.add_argument('--csv_file', type=str, default=None)

    parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

    parser.add_argument('--max_num', type=int, default=-1)
    parser.add_argument('--input_size', type=int, default=224)
    parser.add_argument('--crop_fusion_type', type=str, default='avg')
    parser.add_argument('--gpus', nargs='+', type=int, default=None)
    parser.add_argument('--img_feature_dim', type=int, default=256)
    parser.add_argument('--num_set_segments', type=int, default=1, help='TODO: select multiply set of n-frames from a video')
    parser.add_argument('--pretrain', type=str, default='imagenet')

    args = parser.parse_args()


    class AverageMeter(object):
        """Computes and stores the average and current value"""

        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count


    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


    def parse_shift_option_from_log_name(log_name):
        if 'shift' in log_name:
            strings = log_name.split('_')
            for i, s in enumerate(strings):
                if 'shift' in s:
                    break
            return True, int(strings[i].replace('shift', '')), strings[i + 1]
        else:
            return False, None, None


    weights_list = args.weights.split(',')
    test_segments_list = [int(s) for s in args.test_segments.split(',')]
    assert len(weights_list) == len(test_segments_list)
    if args.coeff is None:
        coeff_list = [1] * len(weights_list)
    else:
        coeff_list = [float(c) for c in args.coeff.split(',')]

    if args.test_list is not None:
        test_file_list = args.test_list.split(',')
    else:
        test_file_list = [None] * len(weights_list)

    data_iter_list = []
    net_list = []
    modality_list = []

    total_num = None
    for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
        is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
        if 'RGB' in this_weights:
            modality = 'RGB'
        else:
            modality = 'Flow'
        this_arch = this_weights.split('TSM_')[1].split('_')[2]
        modality_list.append(modality)
        num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                                modality)
        print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
        net = TSN(num_class, this_test_segments if is_shift else 1, modality,
                  base_model=this_arch,
                  consensus_type=args.crop_fusion_type,
                  img_feature_dim=args.img_feature_dim,
                  pretrain=args.pretrain,
                  is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
                  non_local='_nl' in this_weights,
                  )

        if 'tpool' in this_weights:
            from ops.temporal_shift import make_temporal_pool
            make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

        checkpoint = torch.load(this_weights)
        checkpoint = checkpoint['state_dict']

        # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
        base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
        replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                        'base_model.classifier.bias': 'new_fc.bias',
                        }
        for k, v in replace_dict.items():
            if k in base_dict:
                base_dict[v] = base_dict.pop(k)

        net.load_state_dict(base_dict)

        input_size = net.scale_size if args.full_res else net.input_size
        if args.test_crops == 1:
            cropping = torchvision.transforms.Compose([
                GroupScale(net.scale_size),
                GroupCenterCrop(input_size),
            ])
        elif args.test_crops == 3:  # do not flip, so only 3 crops
            cropping = torchvision.transforms.Compose([
                GroupFullResSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 5:  # do not flip, so only 5 crops
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size, flip=False)
            ])
        elif args.test_crops == 10:
            cropping = torchvision.transforms.Compose([
                GroupOverSample(input_size, net.scale_size)
            ])
        else:
            raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops))

        data_loader = torch.utils.data.DataLoader(
            TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality,
                       image_tmpl=prefix,
                       test_mode=True,
                       remove_missing=len(weights_list) == 1,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
        )

        if args.gpus is not None:
            devices = [args.gpus[i] for i in range(args.workers)]
        else:
            devices = list(range(args.workers))

        net = torch.nn.DataParallel(net.cuda())
        net.eval()

        data_gen = enumerate(data_loader)

        if total_num is None:
            total_num = len(data_loader.dataset)
        else:
            assert total_num == len(data_loader.dataset)

        data_iter_list.append(data_gen)
        net_list.append(net)

    output = []


    def eval_video(video_data, net, this_test_segments, modality):
        net.eval()
        with torch.no_grad():
            i, data, label = video_data
            batch_size = label.numel()
            num_crop = args.test_crops
            if args.dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if args.twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
            rst = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if args.softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            return i, rst, label


    proc_start_time = time.time()
    max_num = args.max_num if args.max_num > 0 else total_num

    top1 = AverageMeter()
    top5 = AverageMeter()

    for i, data_label_pairs in enumerate(zip(*data_iter_list)):
        with torch.no_grad():
            if i >= max_num:
                break
            this_rst_list = []
            this_label = None
            for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
                rst = eval_video((i, data, label), net, n_seg, modality)
                this_rst_list.append(rst[1])
                this_label = label
            assert len(this_rst_list) == len(coeff_list)
            for i_coeff in range(len(this_rst_list)):
                this_rst_list[i_coeff] *= coeff_list[i_coeff]
            ensembled_predict = sum(this_rst_list) / len(this_rst_list)

            for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
                output.append([p[None, ...], g])
            cnt_time = time.time() - proc_start_time
            prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
            top1.update(prec1.item(), this_label.numel())
            top5.update(prec5.item(), this_label.numel())
            if i % 20 == 0:
                print('video {} done, total {}/{}, average {:.3f} sec/video, '
                      'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                                  float(cnt_time) / (i + 1) / args.batch_size, top1.avg, top5.avg))

    video_pred = [np.argmax(x[0]) for x in output]
    video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

    video_labels = [x[1] for x in output]

    if args.csv_file is not None:
        print('=> Writing result to csv file: {}'.format(args.csv_file))
        with open(test_file_list[0].replace('test_videofolder.txt', 'category.txt')) as f:
            categories = f.readlines()
        categories = [f.strip() for f in categories]
        with open(test_file_list[0]) as f:
            vid_names = f.readlines()
        vid_names = [n.split(' ')[0] for n in vid_names]
        assert len(vid_names) == len(video_pred)
        if args.dataset != 'somethingv2':  # only output top1
            with open(args.csv_file, 'w') as f:
                for n, pred in zip(vid_names, video_pred):
                    f.write('{};{}\n'.format(n, categories[pred]))
        else:
            with open(args.csv_file, 'w') as f:
                for n, pred5 in zip(vid_names, video_pred_top5):
                    fill = [n]
                    for p in list(pred5):
                        fill.append(p)
                    f.write('{};{};{};{};{};{}\n'.format(*fill))

    cf = confusion_matrix(video_labels, video_pred).astype(float)
    np.save('cm.npy', cf)
    cls_cnt = cf.sum(axis=1)
    cls_hit = np.diag(cf)
    cls_acc = cls_hit / cls_cnt
    print(cls_acc)
    upper = np.mean(np.max(cf, axis=1) / cls_cnt)
    print('upper bound: {}'.format(upper))

    print('-----Evaluation is finished------')
    print('Class Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
    print('Overall Prec@1 {:.02f}% Prec@5 {:.02f}%'.format(top1.avg, top5.avg))
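    A hypothetical invocation (the checkpoint path is illustrative only; note that the script parses is_shift/shift_div/shift_place and the modality out of the weight file name, so checkpoints must keep the TSM_<dataset>_<modality>_<arch>_... naming produced by main.py):

    python test_models.py ucf101 \
        --weights checkpoint/TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e120/ckpt.best.pth.tar \
        --test_segments 8 --test_crops 1 --batch_size 16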

    Writing this up took real effort; if you found it helpful, please like, follow, and bookmark~~~

  • Original article: https://blog.csdn.net/jiebaoshayebuhui/article/details/127856224