【官方】Paddle2.1实现视频理解经典模型 — TSM - 飞桨AI Studio本项目将带大家深入理解视频理解领域经典模型TSM。从模型理论讲解入手,深入到代码实践。实践部分基于TSM模型在UCF101数据集上从训练到推理全流程实现行为识别任务。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/2310889?channelType=0&channel=0视频理解:基于TSM实现UCF101视频理解 - 飞桨AI Studio基于飞桨开源框架构建TSM,并实现对数据集UCF101的视频理解。 - 飞桨AI Studiohttps://aistudio.baidu.com/aistudio/projectdetail/4114499?channelType=0&channel=0
最近一直在做视频相关的项目,后续会陆续出一些视频理解和视频场景运动的案例,视频这块主推paddlevideo,里面应用层面的东西很丰富,paddle在应用侧一直做的比较好,模型训练这块可以结合mmaction2来,其实从实际应用角度来说,我觉得用paddle和pytorch训练都无所谓,部署的话可能以往我的经验更多是onnx,tensort服务侧的,目前来看,主要也就是服务器,端侧和页面侧的部署这三块,我看paddle分别有paddle inference、lite、js,国产框架中确实是首屈一指的,但是我自己的感觉是从我以前训练gan的结果看,paddle貌似要比pytorch的结果,一样的数据,一样的参数配置,好像要差一点。本文主要介绍一下tsm模块,利用2dcnn来模拟时序信息。视频中核心是视频动作识别,本质就是视频分类,可以用作特征提取,视频时序提取是输入一段长视频获取其中的时序片段,时空定位是同时获取视频中的人物物体的空间位置,核心三大任务,除此之外视频特征提取embedding,这块主要是结合多模态去做,视频,音频和文本侧特征的综合利用和提取。
1.时序信息维度
上述这个视频序列从左向右播放和从右向左播放表达的意思是不同的,视频理解对视频顺序是强依赖的。
2.temporal shift module
这个模块是核心,其实tsm是可插拔模块,是可以很好的嵌入到resnet等模型中,上述图中,一种颜色是一帧,按照时序T上,一共是四帧,同一帧横向是一个channel,在cnn中channel是统一做cnn的,在a图中是没有shift的,在b中是离线shift操作,可见将channel中第一个向下移动,第二个向上移动,其实至于上下移动几个channel并没有很严的的限制,通常是分成几等分去移动,这样上下移动之后,则第一个channel会向下突出一帧,第二个channel会向上突出一帧,突出帧直接截断,空缺帧直接补0,这样在横向做cnn时,统一channel维度变引入不同色的帧,tsm正是通过这种平移的方式,TSM在特征图中引入 temporal 维度上的上下文交互,通过通道移动操作可以使得在当前帧中包含了前后两帧的通道信息,这样再进2D卷积操作就能像3D卷积一样直接提取视频的时空信息,提高了模型在时间维度上的建模能力。而online模式用于对视频类型的实时预测,在这种情况下,无法预知下一秒的图像,因此只能将channel维度由过去向现在移动,而不能从未来向现在移动。
3.缺点和改进
虽然时间位移的原理很简单,但作者发现直接将空间位移策略应用于时间维度并不能提供高性能和效率。具体来说,如果简单的转移所有通道,则会带来两个问题:
为了解决naive shift的两个问题,TSM给出了相应的解决方法。
4.mmaction2中的代码
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- from mmcv.cnn import NonLocal3d
- from torch.nn.modules.utils import _ntuple
-
- from ..builder import BACKBONES
- from .resnet import ResNet
-
-
- class NL3DWrapper(nn.Module):
- """3D Non-local wrapper for ResNet50.
- Wrap ResNet layers with 3D NonLocal modules.
- Args:
- block (nn.Module): Residual blocks to be built.
- num_segments (int): Number of frame segments.
- non_local_cfg (dict): Config for non-local layers. Default: ``dict()``.
- """
-
- def __init__(self, block, num_segments, non_local_cfg=dict()):
- super(NL3DWrapper, self).__init__()
- self.block = block
- self.non_local_cfg = non_local_cfg
- self.non_local_block = NonLocal3d(self.block.conv3.norm.num_features,
- **self.non_local_cfg)
- self.num_segments = num_segments
-
- def forward(self, x):
- x = self.block(x)
-
- n, c, h, w = x.size()
- x = x.view(n // self.num_segments, self.num_segments, c, h,
- w).transpose(1, 2).contiguous()
- x = self.non_local_block(x)
- x = x.transpose(1, 2).contiguous().view(n, c, h, w)
- return x
-
-
- class TemporalShift(nn.Module):
- """Temporal shift module.
- This module is proposed in
- `TSM: Temporal Shift Module for Efficient Video Understanding
-
`_ - Args:
- net (nn.module): Module to make temporal shift.
- num_segments (int): Number of frame segments. Default: 3.
- shift_div (int): Number of divisions for shift. Default: 8.
- """
-
- def __init__(self, net, num_segments=3, shift_div=8):
- super().__init__()
- self.net = net
- self.num_segments = num_segments
- self.shift_div = shift_div
-
- def forward(self, x):
- """Defines the computation performed at every call.
- Args:
- x (torch.Tensor): The input data.
- Returns:
- torch.Tensor: The output of the module.
- """
- x = self.shift(x, self.num_segments, shift_div=self.shift_div)
- return self.net(x)
-
- @staticmethod
- def shift(x, num_segments, shift_div=3):
- """Perform temporal shift operation on the feature.
- Args:
- x (torch.Tensor): The input feature to be shifted.
- num_segments (int): Number of frame segments.
- shift_div (int): Number of divisions for shift. Default: 3.
- Returns:
- torch.Tensor: The shifted feature.
- """
- # 假设当前feature map的通道是256,shift_div=3,
- # 那么就有256/3的特征进行shift left,256/3的特征进行shift right,其他一部分特征不动
- # num_segments每个视频采样的帧数
- # 每帧有c个通道,
- # [
- # [0_1,0_2,0_3,1_1,1_2,3_5,3_6,3_7] 第一帧,8个通道,但是shift_div表示这个通道维度被切分成3个等分
- # [] 第二帧
- # [] 第三帧
- # ]
- # [N, C, H, W]
- n, c, h, w = x.size()
-
- # [N // num_segments, num_segments, C, H*W]
- # can't use 5 dimensional array on PPL2D backend for caffe
- x = x.view(-1, num_segments, c, h * w)
-
- # get shift fold
- fold = c // shift_div
-
- # split c channel into three parts:
- # left_split, mid_split, right_split
- left_split = x[:, :, :fold, :]
- mid_split = x[:, :, fold:2 * fold, :]
- right_split = x[:, :, 2 * fold:, :]
-
- # can't use torch.zeros(*A.shape) or torch.zeros_like(A)
- # because array on caffe inference must be got by computing
-
- # shift left on num_segments channel in `left_split`
- zeros = left_split - left_split
- blank = zeros[:, :1, :, :]
- left_split = left_split[:, 1:, :, :]
- left_split = torch.cat((left_split, blank), 1)
-
- # shift right on num_segments channel in `mid_split`
- zeros = mid_split - mid_split
- blank = zeros[:, :1, :, :]
- mid_split = mid_split[:, :-1, :, :]
- mid_split = torch.cat((blank, mid_split), 1)
-
- # right_split: no shift
-
- # concatenate
- out = torch.cat((left_split, mid_split, right_split), 2)
-
- # [N, C, H, W]
- # restore the original dimension
- return out.view(n, c, h, w)
-
-
- @BACKBONES.register_module()
- class ResNetTSM(ResNet):
- """ResNet backbone for TSM.
- Args:
- num_segments (int): Number of frame segments. Default: 8.
- is_shift (bool): Whether to make temporal shift in reset layers.
- Default: True.
- non_local (Sequence[int]): Determine whether to apply non-local module
- in the corresponding block of each stages. Default: (0, 0, 0, 0).
- non_local_cfg (dict): Config for non-local module. Default: ``dict()``.
- shift_div (int): Number of div for shift. Default: 8.
- shift_place (str): Places in resnet layers for shift, which is chosen
- from ['block', 'blockres'].
- If set to 'block', it will apply temporal shift to all child blocks
- in each resnet layer.
- If set to 'blockres', it will apply temporal shift to each `conv1`
- layer of all child blocks in each resnet layer.
- Default: 'blockres'.
- temporal_pool (bool): Whether to add temporal pooling. Default: False.
- **kwargs (keyword arguments, optional): Arguments for ResNet.
- """
-
- def __init__(self,
- depth,
- num_segments=8,
- is_shift=True,
- non_local=(0, 0, 0, 0),
- non_local_cfg=dict(),
- shift_div=8,
- shift_place='blockres',
- temporal_pool=False,
- **kwargs):
- super().__init__(depth, **kwargs)
- self.num_segments = num_segments
- self.is_shift = is_shift
- self.shift_div = shift_div
- self.shift_place = shift_place
- self.temporal_pool = temporal_pool
- self.non_local = non_local
- self.non_local_stages = _ntuple(self.num_stages)(non_local)
- self.non_local_cfg = non_local_cfg
-
- def make_temporal_shift(self):
- """Make temporal shift for some layers."""
- if self.temporal_pool:
- num_segment_list = [
- self.num_segments, self.num_segments // 2,
- self.num_segments // 2, self.num_segments // 2
- ]
- else:
- num_segment_list = [self.num_segments] * 4
- if num_segment_list[-1] <= 0:
- raise ValueError('num_segment_list[-1] must be positive')
-
- if self.shift_place == 'block':
-
- def make_block_temporal(stage, num_segments):
- """Make temporal shift on some blocks.
- Args:
- stage (nn.Module): Model layers to be shifted.
- num_segments (int): Number of frame segments.
- Returns:
- nn.Module: The shifted blocks.
- """
- blocks = list(stage.children())
- for i, b in enumerate(blocks):
- blocks[i] = TemporalShift(
- b, num_segments=num_segments, shift_div=self.shift_div)
- return nn.Sequential(*blocks)
-
- self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
- self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
- self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
- self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
-
- elif 'blockres' in self.shift_place:
- n_round = 1
- if len(list(self.layer3.children())) >= 23:
- n_round = 2
-
- def make_block_temporal(stage, num_segments):
- """Make temporal shift on some blocks.
- Args:
- stage (nn.Module): Model layers to be shifted.
- num_segments (int): Number of frame segments.
- Returns:
- nn.Module: The shifted blocks.
- """
- blocks = list(stage.children())
- for i, b in enumerate(blocks):
- if i % n_round == 0:
- blocks[i].conv1.conv = TemporalShift(
- b.conv1.conv,
- num_segments=num_segments,
- shift_div=self.shift_div)
- return nn.Sequential(*blocks)
-
- self.layer1 = make_block_temporal(self.layer1, num_segment_list[0])
- self.layer2 = make_block_temporal(self.layer2, num_segment_list[1])
- self.layer3 = make_block_temporal(self.layer3, num_segment_list[2])
- self.layer4 = make_block_temporal(self.layer4, num_segment_list[3])
-
- else:
- raise NotImplementedError
-
- def make_temporal_pool(self):
- """Make temporal pooling between layer1 and layer2, using a 3D max
- pooling layer."""
-
- class TemporalPool(nn.Module):
- """Temporal pool module.
- Wrap layer2 in ResNet50 with a 3D max pooling layer.
- Args:
- net (nn.Module): Module to make temporal pool.
- num_segments (int): Number of frame segments.
- """
-
- def __init__(self, net, num_segments):
- super().__init__()
- self.net = net
- self.num_segments = num_segments
- self.max_pool3d = nn.MaxPool3d(
- kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
-
- def forward(self, x):
- # [N, C, H, W]
- n, c, h, w = x.size()
- # [N // num_segments, C, num_segments, H, W]
- x = x.view(n // self.num_segments, self.num_segments, c, h,
- w).transpose(1, 2)
- # [N // num_segmnets, C, num_segments // 2, H, W]
- x = self.max_pool3d(x)
- # [N // 2, C, H, W]
- x = x.transpose(1, 2).contiguous().view(n // 2, c, h, w)
- return self.net(x)
-
- self.layer2 = TemporalPool(self.layer2, self.num_segments)
-
- def make_non_local(self):
- # This part is for ResNet50
- for i in range(self.num_stages):
- non_local_stage = self.non_local_stages[i]
- if sum(non_local_stage) == 0:
- continue
-
- layer_name = f'layer{i + 1}'
- res_layer = getattr(self, layer_name)
-
- for idx, non_local in enumerate(non_local_stage):
- if non_local:
- res_layer[idx] = NL3DWrapper(res_layer[idx],
- self.num_segments,
- self.non_local_cfg)
-
- def init_weights(self):
- """Initiate the parameters either from existing checkpoint or from
- scratch."""
- super().init_weights()
- if self.is_shift:
- self.make_temporal_shift()
- if len(self.non_local_cfg) != 0:
- self.make_non_local()
- if self.temporal_pool:
- self.make_temporal_pool()