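• [Hands-On, Anti-Rat-Race] Pioneering a New AI Multimodal Task, Audio-Visual Segmentation: Code Practice and Optimization Tutorial (Part 2)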


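    Preface

    For the theory, please see the previous article:

    Brief overview: we want to know which object in the image is making sound, as in the following video demo:

    The GIF cannot play sound, so picture it yourself: there are many vehicles in the scene, but only this ambulance is making sound, so the sounding object is the one that gets segmented.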

     

     

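     In this scene a singer sometimes sings and sometimes plays an instrument. When the singer is only playing, the person is not segmented; when singing, the person is segmented.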

     

    Code layout by relative path (my version, not the official one)

     

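    You can download everything from my Baidu Netdisk (all data and code included), or download the official code. The official code does not include the data, which is available only on request.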

    Training

    First, look at train.py

    The help strings in the code below explain each argument.

    parser.add_argument("--session_name", default="MS3", type=str,
                        help="MS3 trains on the Multi-sources subset: multi-source data, i.e. several objects may be sounding at the same time")
    parser.add_argument("--visual_backbone", default="resnet", type=str,
                        help="use resnet50 or pvt-v2 as the visual backbone")
    parser.add_argument("--train_batch_size", default=4, type=int)
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--max_epoches", default=5, type=int)
    parser.add_argument("--lr", default=0.0001, type=float)
    parser.add_argument("--num_workers", default=0, type=int)
    # weight decay; the original listing had `5-4`, which evaluates to 1
    parser.add_argument("--wt_dec", default=5e-4, type=float)
    parser.add_argument('--masked_av_flag', action='store_true', default=True,
                        help='use the loss described in the paper: sa/masked_va loss')
    parser.add_argument("--lambda_1", default=0.5, type=float, help='balancing coefficient: weight for balancing l4 loss')
    parser.add_argument("--masked_av_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                        help="the authors' setting: compute sa/masked_va loss in which stages: [0, 1, 2, 3]")
    parser.add_argument('--threshold_flag', action='store_true', default=False,
                        help='whether thresholding the generated masks')
    parser.add_argument("--mask_pooling_type", default='avg', type=str, help='the manner to downsample predicted masks')
    parser.add_argument('--norm_fea_flag', action='store_true', default=False, help='normalize audio-visual features')
    parser.add_argument('--closer_flag', action='store_true', default=False, help='use closer loss for masked_va loss')
    parser.add_argument('--euclidean_flag', action='store_true', default=False,
                        help='use euclidean distance for masked_va loss')
    parser.add_argument('--kl_flag', action='store_true', default=True, help='KL divergence: use kl loss for masked_va loss')
    parser.add_argument("--load_s4_params", action='store_true', default=False,
                        help='use S4 parameters for initialization')
    parser.add_argument("--trained_s4_model_path", type=str, default='', help='pretrained S4 model')
    parser.add_argument("--tpavi_stages", default=[0, 1, 2, 3], nargs='+', type=int,
                        help='TPAVI module: add tpavi block in which stages: [0, 1, 2, 3]')
    parser.add_argument("--tpavi_vv_flag", action='store_true', default=False, help='visual-visual self-attention')
    parser.add_argument("--tpavi_va_flag", action='store_true', default=True, help='visual-audio cross-attention')
    parser.add_argument("--weights", type=str, default='', help='pretrained model to start training from (optional): path of trained model')
    parser.add_argument('--log_dir', default='./train_logs', type=str)
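
    One detail in the arguments above is easy to miss: flags declared with `action='store_true'` together with `default=True` (for example `--masked_av_flag`, `--kl_flag`, `--tpavi_va_flag`) are True whether or not they are passed, so they cannot be switched off from the command line. A minimal sketch of that behavior, plus one possible way to make such a flag switchable. The `--no_kl_flag` option is hypothetical and is not part of the original script:

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # Same pattern as in train.py: store_true combined with default=True.
    parser.add_argument('--kl_flag', action='store_true', default=True)
    # Hypothetical companion switch (NOT in the original script): it writes
    # False into the same destination, so the flag can actually be disabled.
    parser.add_argument('--no_kl_flag', dest='kl_flag', action='store_false')
    return parser


parser = build_parser()
default_args = parser.parse_args([])                 # flag not given
explicit_args = parser.parse_args(['--kl_flag'])     # flag given
disabled_args = parser.parse_args(['--no_kl_flag'])  # hypothetical off switch

print(default_args.kl_flag, explicit_args.kl_flag, disabled_args.kl_flag)
```

    With the original definition alone, the first two cases are indistinguishable, so editing the default in train.py is the only way to turn these losses or modules off.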

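    You can start training by following train.sh.

    Code details

    Next, the visual feature-extraction backbone is chosen according to your setting. Audio features are extracted with VGGish by default.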

    if (args.visual_backbone).lower() == "resnet":
        from model import ResNet_AVSModel as AVSModel
        print('==> Use ResNet50 as the visual backbone...')
    elif (args.visual_backbone).lower() == "pvt":
        from model import PVT_AVSModel as AVSModel
        print('==> Use pvt-v2 as the visual backbone...')
    else:
        raise NotImplementedError("only support the resnet50 and pvt-v2")

    Data loading:

    import os
    import pandas as pd
    import torch
    from torch.utils.data import Dataset
    from torchvision import transforms
    # cfg, load_audio_lm and load_image_in_PIL_to_Tensor are defined elsewhere in the repository

    class MS3Dataset(Dataset):
        """Dataset for multiple sound source segmentation"""
        def __init__(self, split='train'):
            super(MS3Dataset, self).__init__()
            self.split = split
            self.mask_num = 5  # one ground-truth mask per frame
            df_all = pd.read_csv(cfg.DATA.ANNO_CSV, sep=',')
            self.df_split = df_all[df_all['split'] == split]
            print("{}/{} videos are used for {}".format(len(self.df_split), len(df_all), self.split))
            self.img_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            self.mask_transform = transforms.Compose([
                transforms.ToTensor(),
            ])

        def __getitem__(self, index):
            df_one_video = self.df_split.iloc[index]
            video_name = df_one_video.iloc[0]  # first column holds the video name (positional access)
            img_base_path = os.path.join(cfg.DATA.DIR_IMG, video_name)
            audio_lm_path = os.path.join(cfg.DATA.DIR_AUDIO_LOG_MEL, self.split, video_name + '.pkl')
            mask_base_path = os.path.join(cfg.DATA.DIR_MASK, self.split, video_name)
            audio_log_mel = load_audio_lm(audio_lm_path)
            # audio_lm_tensor = torch.from_numpy(audio_log_mel)
            imgs, masks = [], []
            # frames are named <video>.mp4_1.png ... <video>.mp4_5.png
            for img_id in range(1, 6):
                img = load_image_in_PIL_to_Tensor(os.path.join(img_base_path, "%s.mp4_%d.png" % (video_name, img_id)), transform=self.img_transform)
                imgs.append(img)
            # masks are named <video>_1.png ... <video>_5.png
            for mask_id in range(1, self.mask_num + 1):
                mask = load_image_in_PIL_to_Tensor(os.path.join(mask_base_path, "%s_%d.png" % (video_name, mask_id)), transform=self.mask_transform, mode='P')
                masks.append(mask)
            imgs_tensor = torch.stack(imgs, dim=0)
            masks_tensor = torch.stack(masks, dim=0)
            return imgs_tensor, audio_log_mel, masks_tensor, video_name

        def __len__(self):
            return len(self.df_split)
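
    The file-naming convention used in `__getitem__` can be made concrete with a small standalone sketch. The directory names and the video name below are placeholders for illustration. Only the two filename patterns are taken from the dataset code:

```python
import os


def frame_and_mask_paths(img_dir, mask_dir, split, video_name, num_frames=5):
    """Build the frame and mask paths in the layout MS3Dataset reads from."""
    img_base = os.path.join(img_dir, video_name)
    mask_base = os.path.join(mask_dir, split, video_name)
    # Frames: <video>.mp4_<i>.png, masks: <video>_<i>.png, both 1-indexed.
    frames = [os.path.join(img_base, "%s.mp4_%d.png" % (video_name, i))
              for i in range(1, num_frames + 1)]
    masks = [os.path.join(mask_base, "%s_%d.png" % (video_name, i))
             for i in range(1, num_frames + 1)]
    return frames, masks


# Placeholder directory names and video name (assumptions, not real dataset paths).
frames, masks = frame_and_mask_paths("visual_frames", "gt_masks", "train", "demo_video")
print(os.path.basename(frames[0]))
print(os.path.basename(masks[-1]))
```

    A helper like this is handy for checking that every frame and mask exists on disk before starting a long training run.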

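    As you can see, five images are read per sample. I checked the videos and they are all 5 seconds long, so the authors train on one video at a time: the per-second frames of each video, their ground-truth masks, and the audio are trained together.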

    for n_iter, batch_data in enumerate(train_dataloader):
        imgs, audio, mask, _ = batch_data  # [bs, 5, 3, 224, 224], [bs, 5, 1, 96, 64], [bs, 5 or 1, 1, 224, 224]
        imgs = imgs.cuda()
        audio = audio.cuda()
        mask = mask.cuda()
        B, frame, C, H, W = imgs.shape
        imgs = imgs.view(B * frame, C, H, W)  # merge batch and frame dimensions
        mask_num = 5
        mask = mask.view(B * mask_num, 1, H, W)
        audio = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])  # [B*T, 1, 96, 64]
        with torch.no_grad():  # the audio backbone is not trained
            audio_feature = audio_backbone(audio)  # [B*T, 128]
        output, v_map_list, a_fea_list = model(imgs, audio_feature)  # [bs*5, 1, 224, 224]
        loss, loss_dict = IouSemanticAwareLoss(output, mask, a_fea_list, v_map_list,
                                               sa_loss_flag=args.masked_av_flag, lambda_1=args.lambda_1,
                                               count_stages=args.masked_av_stages,
                                               mask_pooling_type=args.mask_pooling_type,
                                               threshold=args.threshold_flag, norm_fea=args.norm_fea_flag,
                                               closer_flag=args.closer_flag, euclidean_flag=args.euclidean_flag,
                                               kl_flag=args.kl_flag)
        avg_meter_total_loss.add({'total_loss': loss.item()})
        avg_meter_iou_loss.add({'iou_loss': loss_dict['iou_loss']})
        avg_meter_sa_loss.add({'sa_loss': loss_dict['sa_loss']})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step += 1
        if (global_step - 1) % 20 == 0:  # log every 20 iterations
            train_log = 'Iter:%5d/%5d, Total_Loss:%.4f, iou_loss:%.4f, sa_loss:%.4f, lr: %.4f' % (
                global_step - 1, max_step, avg_meter_total_loss.pop('total_loss'),
                avg_meter_iou_loss.pop('iou_loss'), avg_meter_sa_loss.pop('sa_loss'),
                optimizer.param_groups[0]['lr'])

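    As you can see, training is simple: load the 5 frames of each video and merge them into the batch dimension with view, extract the audio features, and feed both to the model. Then the loss (an IoU term plus the sa term) is computed.

    The model input has two parts. The first is the image frames, [bs*5, 3, 224, 224], where the factor 5 means each video has 5 frames. The second is the audio, batched the same way: one log-mel patch per frame, which VGGish turns into a [bs*5, 128] feature (see the shape comments in the loop above).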
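
    The batch-and-frame merging can be checked on dummy tensors. This is a minimal sketch assuming 2 videos of 5 frames each. The tensors are filled with placeholder values, and each frame is tagged with its flat index so the ordering after `view` is visible:

```python
import torch

bs, T = 2, 5  # assumed: 2 videos, 5 frames per video
imgs = torch.zeros(bs, T, 3, 224, 224)
audio = torch.zeros(bs, T, 1, 96, 64)   # one log-mel patch per frame
mask = torch.zeros(bs, T, 1, 224, 224)

# Tag every frame with the row it should land on after flattening.
for b in range(bs):
    for t in range(T):
        imgs[b, t] = b * T + t

B, frame, C, H, W = imgs.shape
imgs_flat = imgs.view(B * frame, C, H, W)        # [10, 3, 224, 224]
mask_flat = mask.view(B * frame, 1, H, W)        # [10, 1, 224, 224]
audio_flat = audio.view(-1, audio.shape[2], audio.shape[3], audio.shape[4])  # [10, 1, 96, 64]

print(tuple(imgs_flat.shape), tuple(audio_flat.shape), tuple(mask_flat.shape))
# Row 7 is frame index 2 of video index 1, because frames of one video stay adjacent.
print(int(imgs_flat[7, 0, 0, 0].item()))
```

    Because the frames of a video stay contiguous, the model can later recover the per-video grouping with a reshape to [-1, 5, ...], which is what the TPAVI helpers rely on.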

    class Pred_endecoder(nn.Module):
        # resnet based encoder decoder
        def __init__(self, channel=256, config=None, tpavi_stages=[], tpavi_vv_flag=False, tpavi_va_flag=True):
            super(Pred_endecoder, self).__init__()
            self.cfg = config
            self.tpavi_stages = tpavi_stages
            self.tpavi_vv_flag = tpavi_vv_flag
            self.tpavi_va_flag = tpavi_va_flag
            self.resnet = B2_ResNet()
            self.relu = nn.ReLU(inplace=True)
            self.conv4 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 2048)
            self.conv3 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 1024)
            self.conv2 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 512)
            self.conv1 = self._make_pred_layer(Classifier_Module, [3, 6, 12, 18], [3, 6, 12, 18], channel, 256)
            self.path4 = FeatureFusionBlock(channel)
            self.path3 = FeatureFusionBlock(channel)
            self.path2 = FeatureFusionBlock(channel)
            self.path1 = FeatureFusionBlock(channel)
            for i in self.tpavi_stages:
                setattr(self, f"tpavi_b{i + 1}", TPAVIModule(in_channels=channel, mode='dot'))
                print("==> Build TPAVI block...")
            self.output_conv = nn.Sequential(
                nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1),
                Interpolate(scale_factor=2, mode="bilinear"),
                nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            )
            if self.training:
                self.initialize_weights()

        def pre_reshape_for_tpavi(self, x):
            # x: [B*5, C, H, W]
            _, C, H, W = x.shape
            x = x.reshape(-1, 5, C, H, W)
            x = x.permute(0, 2, 1, 3, 4).contiguous()  # [B, C, T, H, W]
            return x

        def post_reshape_for_tpavi(self, x):
            # x: [B, C, T, H, W]
            # return: [B*T, C, H, W]
            _, C, _, H, W = x.shape
            x = x.permute(0, 2, 1, 3, 4)  # [B, T, C, H, W]
            x = x.view(-1, C, H, W)
            return x

        def tpavi_vv(self, x, stage):
            # x: visual, [B*T, C=256, H, W]
            tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
            x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
            x, _ = tpavi_b(x)  # [B, C, T, H, W]
            x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
            return x

        def tpavi_va(self, x, audio, stage):
            # x: visual, [B*T, C=256, H, W]
            # audio: [B*T, 128]
            # ra_flag: return audio feature list or not
            tpavi_b = getattr(self, f'tpavi_b{stage + 1}')
            audio = audio.view(-1, 5, audio.shape[-1])  # [B, T, 128]
            x = self.pre_reshape_for_tpavi(x)  # [B, C, T, H, W]
            x, a = tpavi_b(x, audio)  # [B, C, T, H, W], [B, T, C]
            x = self.post_reshape_for_tpavi(x)  # [B*T, C, H, W]
            return x, a

        def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
            return block(dilation_series, padding_series, NoLabels, input_channel)

        def forward(self, x, audio_feature=None):
            x = self.resnet.conv1(x)
            x = self.resnet.bn1(x)
            x = self.resnet.relu(x)
            x = self.resnet.maxpool(x)
            x1 = self.resnet.layer1(x)  # BF x 256 x 56 x 56
            x2 = self.resnet.layer2(x1)  # BF x 512 x 28 x 28
            x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
            x4 = self.resnet.layer4_1(x3)  # BF x 2048 x 7 x 7
            # print(x1.shape, x2.shape, x3.shape, x4.shape)
            conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56
            conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
            conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
            conv4_feat = self.conv4(x4)  # BF x 256 x 7 x 7
            # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)
            feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
            a_fea_list = [None] * 4
            if len(self.tpavi_stages) > 0:
                if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
                    raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, '
                                    'tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
                for i in self.tpavi_stages:
                    tpavi_count = 0
                    conv_feat = torch.zeros_like(feature_map_list[i]).cuda()
                    if self.tpavi_vv_flag:
                        conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                        conv_feat += conv_feat_vv
                        tpavi_count += 1
                    if self.tpavi_va_flag:
                        conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)
                        conv_feat += conv_feat_va
                        tpavi_count += 1
                        a_fea_list[i] = a_fea
                    conv_feat /= tpavi_count
                    feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI
            conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
            conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
            conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
            conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
            # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
            pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
            # print(pred.shape)
            return pred, feature_map_list, a_fea_list

        def initialize_weights(self):
            res50 = models.resnet50(pretrained=False)
            resnet50_dict = torch.load(self.cfg.TRAIN.PRETRAINED_RESNET50_PATH)
            res50.load_state_dict(resnet50_dict)
            pretrained_dict = res50.state_dict()
            # print(pretrained_dict.keys())
            all_params = {}
            for k, v in self.resnet.state_dict().items():
                if k in pretrained_dict.keys():
                    v = pretrained_dict[k]
                    all_params[k] = v
                elif '_1' in k:
                    name = k.split('_1')[0] + k.split('_1')[1]
                    v = pretrained_dict[name]
                    all_params[k] = v
                elif '_2' in k:
                    name = k.split('_2')[0] + k.split('_2')[1]
                    v = pretrained_dict[name]
                    all_params[k] = v
            assert len(all_params.keys()) == len(self.resnet.state_dict().keys())
            self.resnet.load_state_dict(all_params)
            print(f'==> Load pretrained ResNet50 parameters from {self.cfg.TRAIN.PRETRAINED_RESNET50_PATH}')

    The network is straightforward and the model definition has no surprises, so let's look at the code inside forward:

    def forward(self, x, audio_feature=None):  # inputs: image frames and the VGGish features of the audio log-mel spectrograms
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x1 = self.resnet.layer1(x)  # BF x 256 x 56 x 56
        x2 = self.resnet.layer2(x1)  # BF x 512 x 28 x 28
        x3 = self.resnet.layer3_1(x2)  # BF x 1024 x 14 x 14
        x4 = self.resnet.layer4_1(x3)  # BF x 2048 x 7 x 7, ResNet feature extraction comes first
        # print(x1.shape, x2.shape, x3.shape, x4.shape)
        conv1_feat = self.conv1(x1)  # BF x 256 x 56 x 56, project every stage to 256 channels
        conv2_feat = self.conv2(x2)  # BF x 256 x 28 x 28
        conv3_feat = self.conv3(x3)  # BF x 256 x 14 x 14
        conv4_feat = self.conv4(x4)  # BF x 256 x 7 x 7
        # print(conv1_feat.shape, conv2_feat.shape, conv3_feat.shape, conv4_feat.shape)
        feature_map_list = [conv1_feat, conv2_feat, conv3_feat, conv4_feat]
        a_fea_list = [None] * 4
        if len(self.tpavi_stages) > 0:  # which stages get a TPAVI module; the paper uses all 4
            if (not self.tpavi_vv_flag) and (not self.tpavi_va_flag):
                raise Exception('tpavi_vv_flag and tpavi_va_flag cannot be False at the same time if len(tpavi_stages)>0, '
                                'tpavi_vv_flag is for video self-attention while tpavi_va_flag indicates the standard version (audio-visual attention)')
            for i in self.tpavi_stages:
                tpavi_count = 0
                conv_feat = torch.zeros_like(feature_map_list[i]).cuda()
                if self.tpavi_vv_flag:
                    conv_feat_vv = self.tpavi_vv(feature_map_list[i], stage=i)
                    conv_feat += conv_feat_vv
                    tpavi_count += 1
                if self.tpavi_va_flag:
                    # TPAVI module (audio-visual attention)
                    conv_feat_va, a_fea = self.tpavi_va(feature_map_list[i], audio_feature, stage=i)
                    conv_feat += conv_feat_va
                    tpavi_count += 1
                    a_fea_list[i] = a_fea
                conv_feat /= tpavi_count
                feature_map_list[i] = conv_feat  # update features of stage-i which conduct TPAVI
        conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14, decoding starts here
        conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
        conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
        conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
        # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
        pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224
        # print(pred.shape)
        return pred, feature_map_list, a_fea_list
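
    The spatial sizes annotated in the comments of forward can be reproduced with plain arithmetic. This is a rough sketch, not part of the repository. It assumes the standard ResNet-50 strides and, following the shape comments, that each FeatureFusionBlock doubles the height and width:

```python
def trace_resolutions(input_size=224):
    """Follow the feature-map side length through encoder and decoder."""
    # Stem: conv1 (stride 2) followed by maxpool (stride 2).
    stem = input_size // 2 // 2
    # layer1 keeps the resolution; layer2 to layer4 each halve it.
    encoder = {"x1": stem, "x2": stem // 2, "x3": stem // 4, "x4": stem // 8}

    # Assumption taken from the shape comments: every fusion block upsamples by 2.
    decoder = {}
    size = encoder["x4"]
    for name in ("conv4_feat", "conv43", "conv432", "conv4321"):
        size *= 2
        decoder[name] = size

    # output_conv contains Interpolate(scale_factor=2), restoring the input size.
    pred = size * 2
    return encoder, decoder, pred


encoder, decoder, pred = trace_resolutions(224)
print(encoder)
print(decoder)
print(pred)
```

    For a 224 x 224 input this gives 56, 28, 14, 7 on the way down and 14, 28, 56, 112, 224 on the way up, matching the comments in the listing.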

    As you can see, the features pass through a TPAVI module, which is a fairly complex one:

    class TPAVIModule(nn.Module):
        def __init__(self, in_channels, inter_channels=None, mode='dot',
                     dimension=3, bn_layer=True):
            """
            args:
                in_channels: original channel size (1024 in the paper)
                inter_channels: channel size inside the block if not specified reduced to half (512 in the paper)
                mode: supports Gaussian, Embedded Gaussian, Dot Product, and Concatenation
                dimension: can be 1 (temporal), 2 (spatial), 3 (spatiotemporal)
                bn_layer: whether to add batch norm
            """
            super(TPAVIModule, self).__init__()
            assert dimension in [1, 2, 3]

            if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
                raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')

            self.mode = mode
            self.dimension = dimension
            self.in_channels = in_channels
            self.inter_channels = inter_channels
            # the channel size is reduced to half inside the block
            if self.inter_channels is None:
                self.inter_channels = in_channels // 2
                if self.inter_channels == 0:
                    self.inter_channels = 1

            ## add align channel
            self.align_channel = nn.Linear(128, in_channels)
            self.norm_layer = nn.LayerNorm(in_channels)
            # assign appropriate convolutional, max pool, and batch norm layers for different dimensions
            if dimension == 3:
                conv_nd = nn.Conv3d
                max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
                bn = nn.BatchNorm3d
            elif dimension == 2:
                conv_nd = nn.Conv2d
                max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
                bn = nn.BatchNorm2d
            else:
                conv_nd = nn.Conv1d
                max_pool_layer = nn.MaxPool1d(kernel_size=(2))
                bn = nn.BatchNorm1d
            # function g in the paper which goes through conv. with kernel size 1
            self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
            if bn_layer:
                self.W_z = nn.Sequential(
                    conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1),
                    bn(self.in_channels)
                )
                nn.init.constant_(self.W_z[1].weight, 0)
                nn.init.constant_(self.W_z[1].bias, 0)
            else:
                self.W_z = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1)
                nn.init.constant_(self.W_z.weight, 0)
                nn.init.constant_(self.W_z.bias, 0)
            # define theta and phi for all operations except gaussian
            if self.mode == "embedded" or self.mode == "dot" or self.mode == "concatenate":
                self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)
                self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1)

            if self.mode == "concatenate":
                self.W_f = nn.Sequential(
                    nn.Conv2d(in_channels=self.inter_channels * 2, out_channels=1, kernel_size=1),
                    nn.ReLU()
                )

        def forward(self, x, audio=None):
            """
            args:
                x: (N, C, T, H, W) for dimension=3; (N, C, H, W) for dimension 2; (N, C, T) for dimension 1
                audio: (N, T, C)
            """
            audio_temp = 0
            batch_size, C = x.size(0), x.size(1)
            if audio is not None:
                # print('==> audio.shape', audio.shape)
                H, W = x.shape[-2], x.shape[-1]
                audio_temp = self.align_channel(audio)  # [bs, T, C]
                audio = audio_temp.permute(0, 2, 1)  # [bs, C, T]
                audio = audio.unsqueeze(-1).unsqueeze(-1)  # [bs, C, T, 1, 1]
                audio = audio.repeat(1, 1, 1, H, W)  # [bs, C, T, H, W]
            else:
                audio = x
            # (N, C, THW)
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # [bs, C, THW]
            # print('g_x.shape', g_x.shape)
            # g_x = x.view(batch_size, C, -1) # [bs, C, THW]
            g_x = g_x.permute(0, 2, 1)  # [bs, THW, C]
            if self.mode == "gaussian":
                theta_x = x.view(batch_size, self.in_channels, -1)
                phi_x = audio.view(batch_size, self.in_channels, -1)
                theta_x = theta_x.permute(0, 2, 1)
                f = torch.matmul(theta_x, phi_x)
            elif self.mode == "embedded" or self.mode == "dot":
                theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # [bs, C', THW]
                phi_x = self.phi(audio).view(batch_size, self.inter_channels, -1)  # [bs, C', THW]
                theta_x = theta_x.permute(0, 2, 1)  # [bs, THW, C']
                f = torch.matmul(theta_x, phi_x)  # [bs, THW, THW]
            elif self.mode == "concatenate":
                theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
                phi_x = self.phi(audio).view(batch_size, self.inter_channels, 1, -1)

                h = theta_x.size(2)
                w = phi_x.size(3)
                theta_x = theta_x.repeat(1, 1, 1, w)
                phi_x = phi_x.repeat(1, 1, h, 1)

                concat = torch.cat([theta_x, phi_x], dim=1)
                f = self.W_f(concat)
                f = f.view(f.size(0), f.size(2), f.size(3))

            if self.mode == "gaussian" or self.mode == "embedded":
                f_div_C = F.softmax(f, dim=-1)
            elif self.mode == "dot" or self.mode == "concatenate":
                N = f.size(-1)  # number of position in x
                f_div_C = f / N  # [bs, THW, THW]

            y = torch.matmul(f_div_C, g_x)  # [bs, THW, C]

            # contiguous here just allocates contiguous chunk of memory
            y = y.permute(0, 2, 1).contiguous()  # [bs, C, THW]
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])  # [bs, C', T, H, W]

            W_y = self.W_z(y)  # [bs, C, T, H, W]
            # residual connection
            z = W_y + x  # [bs, C, T, H, W]
            # add LayerNorm
            z = z.permute(0, 2, 3, 4, 1)  # [bs, T, H, W, C]
            z = self.norm_layer(z)
            z = z.permute(0, 4, 1, 2, 3)  # [bs, C, T, H, W]

            return z, audio_temp

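    The code looks complicated, but mostly because the authors support several mode options and do a lot of channel and shape shuffling. The actual operations boil down to a few 1x1x1 3D convolutions followed by matrix multiplications and additions. Note that a 1x1x1 kernel only mixes channels; the interaction across time and space comes from the attention matmul over all T*H*W positions.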

    if dimension == 3:
        conv_nd = nn.Conv3d
        max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
        bn = nn.BatchNorm3d
    elif dimension == 2:
        conv_nd = nn.Conv2d
        max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
        bn = nn.BatchNorm2d
    else:
        conv_nd = nn.Conv1d
        max_pool_layer = nn.MaxPool1d(kernel_size=(2))
        bn = nn.BatchNorm1d
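
    To see the core of the module without the mode switches, here is a stripped-down sketch of the `dot` branch with audio input. It is an illustration only: the class name `TinyAVAttention` is made up, the channel count is tiny so it runs quickly, and the batch norm, zero initialization and LayerNorm of the original are left out:

```python
import torch
import torch.nn as nn


class TinyAVAttention(nn.Module):
    """Simplified dot-mode audio-visual attention (illustrative, not the original class)."""

    def __init__(self, channels=8, audio_dim=128):
        super().__init__()
        self.inter = channels // 2
        self.align = nn.Linear(audio_dim, channels)  # audio 128-d -> visual channel size
        self.theta = nn.Conv3d(channels, self.inter, kernel_size=1)
        self.phi = nn.Conv3d(channels, self.inter, kernel_size=1)
        self.g = nn.Conv3d(channels, self.inter, kernel_size=1)
        self.w_z = nn.Conv3d(self.inter, channels, kernel_size=1)

    def forward(self, x, audio):
        # x: [B, C, T, H, W], audio: [B, T, audio_dim]
        B, C, T, H, W = x.shape
        a = self.align(audio).permute(0, 2, 1)                      # [B, C, T]
        a = a[..., None, None].expand(B, C, T, H, W).contiguous()   # copy over H and W
        theta = self.theta(x).flatten(2).transpose(1, 2)            # [B, THW, C']
        phi = self.phi(a).flatten(2)                                # [B, C', THW]
        g = self.g(x).flatten(2).transpose(1, 2)                    # [B, THW, C']
        attn = torch.matmul(theta, phi) / phi.shape[-1]             # [B, THW, THW], divided by N
        y = torch.matmul(attn, g)                                   # [B, THW, C']
        y = y.transpose(1, 2).reshape(B, self.inter, T, H, W)       # [B, C', T, H, W]
        return x + self.w_z(y)                                      # residual connection


torch.manual_seed(0)
block = TinyAVAttention(channels=8)
x = torch.randn(1, 8, 5, 4, 4)      # 1 video, 5 frames, 4x4 feature map
audio = torch.randn(1, 5, 128)      # one 128-d audio feature per frame
with torch.no_grad():
    out = block(x, audio)
print(tuple(out.shape))
```

    The output keeps the shape of the visual input, which is why the module can be dropped into any of the four encoder stages.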

    Finally, several decoder blocks fuse the features, and the output layer reduces them to a single channel:

    conv4_feat = self.path4(feature_map_list[3])  # BF x 256 x 14 x 14
    conv43 = self.path3(conv4_feat, feature_map_list[2])  # BF x 256 x 28 x 28
    conv432 = self.path2(conv43, feature_map_list[1])  # BF x 256 x 56 x 56
    conv4321 = self.path1(conv432, feature_map_list[0])  # BF x 256 x 112 x 112
    # print(conv4_feat.shape, conv43.shape, conv432.shape, conv4321.shape)
    pred = self.output_conv(conv4321)  # BF x 1 x 224 x 224

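    Note the channel dimension of 1 in [BF x 1 x 224 x 224]: this is the network's prediction head. The final output is bs*frame maps of size 1 x 224 x 224. After binarization into 0/1 classes (for example by thresholding), these become the predicted masks.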
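
    The step from raw output to a 0/1 mask can be sketched as follows. Since the output has a single channel, this sketch assumes a sigmoid followed by a 0.5 threshold rather than an argmax. The function names are illustrative, and the utilities in the official repository may differ in detail:

```python
import torch


def logits_to_mask(pred, threshold=0.5):
    """Turn [BF, 1, H, W] logits into a 0/1 mask via sigmoid and a threshold."""
    prob = torch.sigmoid(pred)
    return (prob > threshold).to(torch.uint8)


def mask_iou(pred_mask, gt_mask, eps=1e-7):
    """Intersection over union between two 0/1 uint8 masks."""
    inter = (pred_mask & gt_mask).sum().item()
    union = (pred_mask | gt_mask).sum().item()
    return inter / (union + eps)


# Tiny 2x2 example with hand-picked logits (illustrative values).
pred = torch.tensor([[[[2.0, -1.0],
                       [0.5, -3.0]]]])
gt = torch.tensor([[[[1, 0],
                     [0, 0]]]], dtype=torch.uint8)

mask = logits_to_mask(pred)
iou = mask_iou(mask, gt)
print(mask.tolist(), round(iou, 3))
```

    Positive logits map to 1 and negative logits to 0, so in this toy case two pixels are predicted as sounding and one of them overlaps the ground truth.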

    Here is one of my predictions:

     

    Testing

    First, look at the data in ms3_meta_data.csv

     

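    As you can see, there are three splits: training, validation, and test. Once the model is trained, run test.py to evaluate it. It tests on the data in the test folder and puts the results in the test_log folder. To run the test code, just change the path to your trained model and you will see the results.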
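
    A quick way to check how many videos fall into each split is to count the `split` column, which is the column the dataset class filters on. The rows below are made-up placeholders, and the `video_id` header is an assumption; only the `split` column name comes from the dataset code:

```python
import csv
import io
from collections import Counter

# Hypothetical rows for illustration only; the real file is ms3_meta_data.csv.
SAMPLE_CSV = """video_id,split
video_a,train
video_b,train
video_c,val
video_d,test
"""


def count_splits(csv_text):
    """Count how many rows belong to each value of the `split` column."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return Counter(row["split"] for row in reader)


counts = count_splits(SAMPLE_CSV)
print(dict(counts))
```

    To run it on the real file, read the text with `open(path).read()` and pass it to `count_splits`.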

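    Testing a specific video

    Put the videos you want to test in raw_videos/ under avsbench_data/det/det. Clips of 5 s are recommended, because 5 frames are extracted, unless you modify the code.

    Then run preprocess_scripts/preprocess_ms3.py. It generates the audio log-mel spectrograms and extracts the frames, and saves the outputs alongside raw_videos.

    Next, run detect.py (in the same directory as train.py) to run inference on your videos.

    Real-time detection: I am still writing this code, so please bear with me.

    All code links (local files cannot be uploaded, so only the original GitHub repository is given): https://github.com/OpenNLPLab/AVSBench

    Finally

    I will soon record a video that walks through the theory, the code, and training and inference. Stay tuned~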

  • Original article: https://blog.csdn.net/qq_46098574/article/details/126255334