[3D Reconstruction] PatchMatchNet Reproduction Notes


    Original paper: PatchmatchNet: Learned Multi-View Patchmatch Stereo
    The walkthrough below can be followed step by step against the source code: PatchmatchNet source code

    1 Key Contributions

    (figure: comparison of GPU memory, runtime and overall error on DTU)
    Under limited GPU memory and runtime budgets, PatchMatchNet achieves a mid-range overall error on the DTU benchmark while using comparatively little GPU memory and runtime, making it the accuracy/efficiency compromise among the 2020 multi-view stereo (MVS) methods.

    Highlights:
    Fast and memory-efficient, it can handle higher-resolution images and is far more efficient than existing models: at least 2.5x faster than the state of the art, with roughly half the memory usage.
    It is the first to bring iterative, multi-scale Patchmatch into an end-to-end trainable architecture, and it improves the classic Patchmatch core algorithm with a novel, learnable adaptive propagation and per-iteration evaluation scheme.

    Main contribution:
    Learning-based methods outperform traditional ones but are constrained by memory and runtime. The work therefore introduces the Patchmatch idea into end-to-end trainable deep learning, augmenting the traditional propagation and cost-evaluation steps with learnable adaptive modules, which reduces both memory consumption and runtime.

    2 Dataset Description

    (1) Before learning PatchMatchNet, it helps to understand the DTU dataset, since its characteristics shape the implementation. DTU is a multi-view dataset captured under controlled conditions: it contains multi-view images of 128 objects, taken with 64 fixed cameras (so there are 64 sets of camera intrinsics and extrinsics) with overlapping fields of view. The camera parameters take the following form (a minimal parsing sketch is given after the listing):

    extrinsic (rotation matrix R and translation T)
    0.126794 -0.880314 0.457133 -272.105
    0.419456 0.465205 0.779513 -485.147
    -0.898877 0.09291 0.428238 629.679
    0.0 0.0 0.0 1.0
    
    intrinsic (pinhole camera parameters fx, fy, cx, cy)
    2892.33 0 823.206
    0 2883.18 619.07
    0 0 1
    
    425 2.5 (the depth range values; the original code expects the smaller value first, then the larger)
    
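    To make the camera files concrete, here is a minimal sketch of how a cam.txt laid out as above could be parsed (an illustrative helper written for this note, not the loader from the repository; the line offsets assume exactly the blocks shown above):

    import numpy as np

    def read_cam_file(path: str):
        """Parse a DTU/MVSNet-style cam.txt: an 'extrinsic' 4x4 block,
        an 'intrinsic' 3x3 block, and a final line of depth values."""
        with open(path) as f:
            lines = [ln.strip() for ln in f if ln.strip()]
        # lines[0] == "extrinsic", lines[1:5] hold the 4x4 matrix
        extrinsic = np.array(" ".join(lines[1:5]).split(), dtype=np.float32).reshape(4, 4)
        # lines[5] == "intrinsic", lines[6:9] hold the 3x3 matrix
        intrinsic = np.array(" ".join(lines[6:9]).split(), dtype=np.float32).reshape(3, 3)
        # last line: the depth values, smaller one first
        depth_values = [float(v) for v in lines[9].split()]
        return intrinsic, extrinsic, depth_values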

    (2) Using the COLMAP software (usage not covered here), the capture setup of object scan1 can be inspected as shown in the figure below. It contains 49 images of the hat, i.e. 49 camera positions; after extracting and matching feature points, the sparse reconstruction yields a point cloud of 26,134 points:
    (figure: COLMAP view of scan1's camera positions and sparse point cloud)
    (3) The training data directory is organized as follows:

    Training data root
    	+---Cameras_1 (camera parameters)
    	|   +---00000000_cam.txt
    	|   +---00000001_cam.txt
    	|   +---00000002_cam.txt
    	|   ...... 64 camera-parameter txt files (some camera positions are unused)
    	|   +---pair.txt (view-overlap matching file, one per dataset)
    	|   \---train (contains 64 camera-parameter txt files)
    	|       +---00000000_cam.txt
    	|       +---00000001_cam.txt
    	|       +---00000002_cam.txt
    	|       ......
    	+---Depths_raw (depth maps)
    	|   +---scan1
    	|       +---depth_map_0000.pfm (PFM depth map, width 160 x height 128)
    	|       +---depth_map_0001.pfm
    	|       +---depth_map_0002.pfm
    	|       +---depth_map_0003.pfm
    	|       ......
    	|       +---depth_visual_0044.png (PNG black-and-white visualization of the depth map, width 160 x height 128)
    	|       +---depth_visual_0045.png
    	|       +---depth_visual_0046.png
    	|       +---depth_visual_0047.png
    	|       +---depth_visual_0048.png
    	|       ......
    	|   +---scan2
    	|   +---scan3
    	|   +---scan4
    	|   +---scan5
    	|   +---scan6
    	|   +---scan7
    	|   \---scan8
    	\---Rectified
    	    +---scan1_train
    	        +---rect_001_0_r5000.png
    	        +---rect_001_1_r5000.png
    	        +---rect_001_2_r5000.png
    	        ....
    	    +---scan2_train
    	    +---scan3_train
    	    +---scan4_train
    	    +---scan5_train
    	    +---scan6_train
    	    +---scan7_train
    	    \---scan8_train
    

    The example above covers data for 8 objects, about 3 GB in total, available for download and testing (download link). It also contains two test scans, whose directory structure is as follows:

    Test data root
    +---scan1
    |   +---cams (64 camera intrinsics/extrinsics and depth ranges)
    |   +---cams_1 (64 camera intrinsics/extrinsics and depth ranges)
    |   +---images (49 multi-view images, width 1600 x height 1200)
    |   \---pair.txt (view-overlap matching file, one per scan)
    \---scan4
        +---cams
        +---cams_1
        +---images
        \---pair.txt
    
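    Since pair.txt decides which source views are matched to each reference view, here is a minimal sketch of how an MVSNet-style pair.txt is commonly parsed (this assumes the usual layout: the total number of viewpoints on the first line, then for every reference view one line with its id and one line with the number of source views followed by alternating view-id / score pairs):

    from typing import List, Tuple

    def read_pair_file(path: str) -> List[Tuple[int, List[int]]]:
        """Return a list of (ref_view, [src_views]) entries."""
        pairs = []
        with open(path) as f:
            num_viewpoints = int(f.readline().rstrip())
            for _ in range(num_viewpoints):
                ref_view = int(f.readline().rstrip())
                tokens = f.readline().rstrip().split()
                # tokens[0] is the number of source views; the rest alternates id, score
                src_views = [int(v) for v in tokens[1::2]]
                pairs.append((ref_view, src_views))
        return pairs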

    To convert your own data into this test format, see the notes on the new training method in the source repository. The test data differ from the training data in two ways:
    1. The images are larger (1600x1200); 2. no depth maps are provided: they are predicted with the trained model, and the final output is a point-cloud .ply file.

    The author splits all scans into training, validation and test sets, stored in separate txt files inside the lists folder, which record which scans are used for training and which for testing. The directory looks like this:

    lists
    	├─dtu
    	│      all.txt
    	│      test.txt
    	│      train.txt
    	│      val.txt
    

    3 Training PatchMatchNet

    3.1 Input parameters

    A few important parameters are listed below; the help strings explain them.

    "--trainpath", default="data/mini_dtu/train/", help="path to the training set" (user-defined)
    "--epochs", type=int, default=16, help="number of training epochs" (user-defined)
    "--batch_size", type=int, default=1, help="training batch size" (user-defined)
    "--loadckpt", default=None, help="load a specific checkpoint file" (none by default)
    "--parallel", action="store_true", default=False, help="if set, use parallel execution, which prevents exporting a TorchScript model."
    "--patchmatch_iteration", nargs="+", type=int, default=[1, 2, 2], help="number of patchmatch iterations at stages 1, 2, 3"
    "--patchmatch_num_sample", nargs="+", type=int, default=[8, 8, 16], help="number of samples generated by local perturbation at stages 1, 2, 3"
    "--patchmatch_interval_scale", nargs="+", type=float, default=[0.005, 0.0125, 0.025], help="normalized interval in the inverse depth range used to generate local perturbation samples"
    "--patchmatch_range", nargs="+", type=int, default=[6, 4, 2], help="fixed offsets of the sampling points for propagation at stages 1, 2, 3"
    "--propagate_neighbors", nargs="+", type=int, default=[0, 8, 16], help="number of neighbors for adaptive propagation at stages 1, 2, 3"
    "--evaluate_neighbors", nargs="+", type=int, default=[9, 9, 9], help="number of neighbors for adaptive matching-cost aggregation in adaptive evaluation at stages 1, 2, 3"
    
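    For context, a minimal argparse sketch mirroring the options listed above (an illustrative excerpt, not the full argument list of the training script):

    import argparse

    parser = argparse.ArgumentParser(description="PatchmatchNet training options (excerpt)")
    parser.add_argument("--trainpath", default="data/mini_dtu/train/", help="path to the training set")
    parser.add_argument("--epochs", type=int, default=16, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1, help="training batch size")
    parser.add_argument("--loadckpt", default=None, help="load a specific checkpoint file")
    parser.add_argument("--patchmatch_iteration", nargs="+", type=int, default=[1, 2, 2],
                        help="number of patchmatch iterations at stages 1, 2, 3")
    parser.add_argument("--patchmatch_num_sample", nargs="+", type=int, default=[8, 8, 16],
                        help="number of local-perturbation samples at stages 1, 2, 3")
    args = parser.parse_args()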

    3.2 Setting up the dataset loading

    # dataset, dataloader
    train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, robust_train=True)
    test_dataset = MVSDataset(args.valpath, args.vallist, "val", 5,  robust_train=False)
    
    TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, drop_last=True)
    TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)
    

    Input: the path to the training set, the train.txt list, the training mode, the number of views to use (at most 10 neighboring source views plus 1 reference view), and the robust-training flag (randomly select 5 unordered views from the 10).

    Output: for one metas element, the 5 images, their intrinsics and extrinsics, the depth map depth, the depth min/max values, and the PNG depth visualization.

    What the MVSDataset class does:
    1. Sets the number of stages to 4.
    2. Reads the training scan list.
    3. Builds an empty list metas storing [scan, light_idx (each viewpoint is captured under 7 different lighting conditions), the reference view ref, and the set of its 10 neighboring source views src].
    4. To fetch a sample: read one metas element; with robust training, use the reference view ref plus 4 source views chosen at random from the 10, otherwise use ref plus the first 4 source views in order.

    The data are then read:

    (1) From the Rectified folder, read the rectified (width 640 x height 512) reference image ref and all src images (5 color images in total; note that the view IDs run from 0 to 48 while the image files are numbered 1 to 49, so the image is read with ID+1).

    (2) From the Depths_raw folder, read the black-and-white depth-visualization PNGs (width 160 x height 128) and the PFM depth maps (width 160 x height 128) for ref and all src views (5 of each); their IDs match the view IDs 0 to 48, so no +1 is needed.

    (3) From the Cameras_1 folder, read the intrinsics and extrinsics of the 5 views and the depth range of the ref view.

    (4) The intrinsics read here correspond to the 160x128 images, but intrinsics matching the larger images are needed, so they are scaled up by a factor of 4:

    # intrinsics and extrinsics for all views
    intrinsic[:2, :] *= 4.0
    intrinsics.append(intrinsic)
    extrinsics.append(extrinsic)
    

    (5) Read [the depth range of the reference view ref, the PNG depth visualization (width 640 x height 512), and the PFM depth map (width 640 x height 512)].
    Note: to compute the training loss, the prediction and the ground-truth label must have the same size, otherwise errors keep occurring; the label of width 160 x height 128 therefore has to be brought to the same size as the width 640 x height 512 color image. The approach here is to first resize the label to width 1600 x height 1200 and then run the repo's prepare_img, which finally produces a width 640 x height 512 depth map:

    def prepare_img(hr_img: np.ndarray) -> np.ndarray:
        # original w,h: 1600, 1200; downsample -> 800, 600; crop -> 640, 512
        # OpenCV note: a single call img = cv.resize(img, (width, height), interpolation=cv.INTER_LINEAR) is enough
        # note: hr_img is a single-channel label here, shape (H, W)
        # resize added so the label size lines up with the expected input size
        hr_img = cv2.resize(hr_img, (1600, 1200), interpolation=cv2.INTER_NEAREST)
        # downsample
        h, w = hr_img.shape
        hr_img_ds = cv2.resize(hr_img, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST)
        # crop
        h, w = hr_img_ds.shape
        target_h, target_w = 512, 640
        start_h, start_w = (h - target_h) // 2, (w - target_w) // 2
        hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w]
        return hr_img_crop
    
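    A quick sanity check of the function above (the random array stands in for a 160x128 label already read from disk; cv2 must be imported where prepare_img is defined):

    import numpy as np

    label = np.random.rand(128, 160).astype(np.float32)   # dummy depth label, shape (H, W)
    out = prepare_img(label)
    print(out.shape)                                       # expected: (512, 640), i.e. height 512, width 640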

    3.3 Training a sample

    Once a batch of data has been loaded, the train_sample method runs the training step.
    Input: one batch of samples (for example, with batch_size=1 there is a single metas element containing 1 reference image and 4 source images).

    (1) First, build the depth maps depth_gt and the label masks for the 4 stages.
    Using nearest-neighbor interpolation, this returns the original resolution plus 1/2, 1/4 and 1/8 of it:

    def create_stage_images(image: torch.Tensor) -> List[torch.Tensor]:
        return [
            image,
            F.interpolate(image, scale_factor=0.5, mode="nearest"),
            F.interpolate(image, scale_factor=0.25, mode="nearest"),
            F.interpolate(image, scale_factor=0.125, mode="nearest")
        ]
    
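    For example, a small shape check (assuming the ground-truth depth has already been expanded to [B, 1, H, W]):

    import torch

    depth_gt = torch.rand(1, 1, 512, 640)        # dummy ground-truth depth, [B, 1, H, W]
    stages = create_stage_images(depth_gt)
    print([tuple(s.shape) for s in stages])
    # [(1, 1, 512, 640), (1, 1, 256, 320), (1, 1, 128, 160), (1, 1, 64, 80)]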

    (2) The [5 images, their intrinsics, their extrinsics, and the min/max depth of the ref depth map] are fed into the model, which is defined in the net module:

    class PatchmatchNet(nn.Module):
        """ Implementation of complete structure of PatchmatchNet"""
    
        def __init__(
            self,
            patchmatch_interval_scale: List[float],
            propagation_range: List[int],
            patchmatch_iteration: List[int],
            patchmatch_num_sample: List[int],
            propagate_neighbors: List[int],
            evaluate_neighbors: List[int],
        ) -> None:
            """Initialize modules in PatchmatchNet
    
            Args:
                patchmatch_interval_scale: depth interval scale in patchmatch module
                propagation_range: propagation range
                patchmatch_iteration: patchmatch iteration number
                patchmatch_num_sample: patchmatch number of samples
                propagate_neighbors: number of propagation neighbors
                evaluate_neighbors: number of propagation neighbors for evaluation
            """
            super(PatchmatchNet, self).__init__()
    
            self.stages = 4
            self.feature = FeatureNet()
            self.patchmatch_num_sample = patchmatch_num_sample
    
            num_features = [16, 32, 64]
    
            self.propagate_neighbors = propagate_neighbors
            self.evaluate_neighbors = evaluate_neighbors
            # number of groups for group-wise correlation
            self.G = [4, 8, 8]
    
            for i in range(self.stages - 1):
                patchmatch = PatchMatch(
                    propagation_out_range=propagation_range[i],
                    patchmatch_iteration=patchmatch_iteration[i],
                    patchmatch_num_sample=patchmatch_num_sample[i],
                    patchmatch_interval_scale=patchmatch_interval_scale[i],
                    num_feature=num_features[i],
                    G=self.G[i],
                    propagate_neighbors=self.propagate_neighbors[i],
                    evaluate_neighbors=evaluate_neighbors[i],
                    stage=i + 1,
                )
                setattr(self, f"patchmatch_{i+1}", patchmatch)
    
            self.upsample_net = Refinement()
    
        def forward(
            self,
            images: List[torch.Tensor],
            intrinsics: torch.Tensor,
            extrinsics: torch.Tensor,
            depth_min: torch.Tensor,
            depth_max: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, Dict[int, List[torch.Tensor]]]:
            """Forward method for PatchMatchNet
    
            Args:
                images: N images (B, 3, H, W) stored in list
                intrinsics: intrinsic 3x3 matrices for all images (B, N, 3, 3)
                extrinsics: extrinsic 4x4 matrices for all images (B, N, 4, 4)
                depth_min: minimum virtual depth (B, 1)
                depth_max: maximum virtual depth (B, 1)
    
            Returns:
                output tuple of PatchMatchNet, containing refined depthmap, depth patchmatch, and photometric confidence.
            """
            assert len(images) == intrinsics.size()[1], "Different number of images and intrinsic matrices"
            assert len(images) == extrinsics.size()[1], 'Different number of images and extrinsic matrices'
            images, intrinsics, orig_height, orig_width = adjust_image_dims(images, intrinsics)
            ref_image = images[0]
            _, _, ref_height, ref_width = ref_image.size()
    
            # step 1. Multi-scale feature extraction
            features: List[Dict[int, torch.Tensor]] = []
            for img in images:
                output_feature = self.feature(img)
                features.append(output_feature)
            del images
            ref_feature, src_features = features[0], features[1:]
    
            depth_min = depth_min.float()
            depth_max = depth_max.float()
    
            # step 2. Learning-based patchmatch
            device = intrinsics.device
            depth = torch.empty(0, device=device)
            depths: List[torch.Tensor] = []
            score = torch.empty(0, device=device)
            view_weights = torch.empty(0, device=device)
            depth_patchmatch: Dict[int, List[torch.Tensor]] = {}
    
            scale = 0.125
            for stage in range(self.stages - 1, 0, -1):
                src_features_l = [src_fea[stage] for src_fea in src_features]
    
                # Create projection matrix for specific stage
                intrinsics_l = intrinsics.clone()
                intrinsics_l[:, :, :2] *= scale
                proj = extrinsics.clone()
                proj[:, :, :3, :4] = torch.matmul(intrinsics_l, extrinsics[:, :, :3, :4])
                proj_l = torch.unbind(proj, 1)
                ref_proj, src_proj = proj_l[0], proj_l[1:]
                scale *= 2.0
    
                # Need conditional since TorchScript only allows "getattr" access with string literals
                if stage == 3:
                    depths, score, view_weights = self.patchmatch_3(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
                elif stage == 2:
                    depths, score, view_weights = self.patchmatch_2(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
                elif stage == 1:
                    depths, score, view_weights = self.patchmatch_1(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
    
                depth_patchmatch[stage] = depths
                depth = depths[-1].detach()
    
                if stage > 1:
                    # upsampling the depth map and pixel-wise view weight for next stage
                    depth = F.interpolate(depth, scale_factor=2.0, mode="nearest")
                    view_weights = F.interpolate(view_weights, scale_factor=2.0, mode="nearest")
    
            del ref_feature
            del src_features
    
            # step 3. Refinement
            depth = self.upsample_net(ref_image, depth, depth_min, depth_max)
            if ref_width != orig_width or ref_height != orig_height:
                depth = F.interpolate(depth, size=[orig_height, orig_width], mode='bilinear', align_corners=False)
            depth_patchmatch[0] = [depth]
    
            if self.training:
                return depth, torch.empty(0, device=device), depth_patchmatch
            else:
                num_depth = self.patchmatch_num_sample[0]
                score_sum4 = 4 * F.avg_pool3d(
                    F.pad(score.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
                ).squeeze(1)
                # [B, 1, H, W]
                depth_index = depth_regression(
                    score, depth_values=torch.arange(num_depth, device=score.device, dtype=torch.float)
                ).long().clamp(0, num_depth - 1)
                photometric_confidence = torch.gather(score_sum4, 1, depth_index)
                photometric_confidence = F.interpolate(
                    photometric_confidence, size=[orig_height, orig_width], mode="nearest").squeeze(1)
    
                return depth, photometric_confidence, depth_patchmatch
    

    Inside the network, the 5 input images are slightly stretched or compressed so that their width and height are multiples of 8, which lets the feature-map size halve cleanly from stage to stage.

    def adjust_image_dims(
            images: List[torch.Tensor], intrinsics: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor, int, int]:
        # stretch or compress image slightly to ensure width and height are multiples of 8
        _, _, ref_height, ref_width = images[0].size()
        for i in range(len(images)):
            _, _, height, width = images[i].size()
            new_height = int(round(height / 8)) * 8
            new_width = int(round(width / 8)) * 8
            if new_width != width or new_height != height:
                intrinsics[:, i, 0] *= new_width / width  # the intrinsics are updated to match the resized images
                intrinsics[:, i, 1] *= new_height / height
                images[i] = nn.functional.interpolate(
                    images[i], size=[new_height, new_width], mode='bilinear', align_corners=False)  # note: the interpolation size is (height, width), the opposite of OpenCV's (width, height)
    
        return images, intrinsics, ref_height, ref_width
    

    (3) Once the 5 images have been resized to multiples of 8 and their intrinsics updated accordingly, the resulting images are fed into the feature extraction network FeatureNet:
    (figure: FeatureNet architecture)

    class FeatureNet(nn.Module):
        """Feature Extraction Network: to extract features of original images from each view"""
    
        def __init__(self):
            """Initialize different layers in the network"""
    
            super(FeatureNet, self).__init__()
    
            self.conv0 = ConvBnReLU(3, 8, 3, 1, 1)
            # [B,8,H,W]
            self.conv1 = ConvBnReLU(8, 8, 3, 1, 1)
            # [B,16,H/2,W/2]
            self.conv2 = ConvBnReLU(8, 16, 5, 2, 2)
            self.conv3 = ConvBnReLU(16, 16, 3, 1, 1)
            self.conv4 = ConvBnReLU(16, 16, 3, 1, 1)
            # [B,32,H/4,W/4]
            self.conv5 = ConvBnReLU(16, 32, 5, 2, 2)
            self.conv6 = ConvBnReLU(32, 32, 3, 1, 1)
            self.conv7 = ConvBnReLU(32, 32, 3, 1, 1)
            # [B,64,H/8,W/8]
            self.conv8 = ConvBnReLU(32, 64, 5, 2, 2)
            self.conv9 = ConvBnReLU(64, 64, 3, 1, 1)
            self.conv10 = ConvBnReLU(64, 64, 3, 1, 1)
    
            self.output1 = nn.Conv2d(64, 64, 1, bias=False)
            self.inner1 = nn.Conv2d(32, 64, 1, bias=True)
            self.inner2 = nn.Conv2d(16, 64, 1, bias=True)
            self.output2 = nn.Conv2d(64, 32, 1, bias=False)
            self.output3 = nn.Conv2d(64, 16, 1, bias=False)
    
        def forward(self, x: torch.Tensor) -> Dict[int, torch.Tensor]:
            """Forward method
    
            Args:
                x: images from a single view, in the shape of [B, C, H, W]. Generally, C=3
    
            Returns:
                output_feature: a python dictionary contains extracted features from stage 1 to stage 3
                    keys are 1, 2, and 3
            """
            output_feature: Dict[int, torch.Tensor] = {}
    
            conv1 = self.conv1(self.conv0(x))
            conv4 = self.conv4(self.conv3(self.conv2(conv1)))
    
            conv7 = self.conv7(self.conv6(self.conv5(conv4)))
            conv10 = self.conv10(self.conv9(self.conv8(conv7)))
    
            output_feature[3] = self.output1(conv10)
            intra_feat = F.interpolate(conv10, scale_factor=2.0, mode="bilinear", align_corners=False) + self.inner1(conv7)
            del conv7
            del conv10
    
            output_feature[2] = self.output2(intra_feat)
            intra_feat = F.interpolate(
                intra_feat, scale_factor=2.0, mode="bilinear", align_corners=False) + self.inner2(conv4)
            del conv4
    
            output_feature[1] = self.output3(intra_feat)
            del intra_feat
    
            return output_feature
    

    FeatureNet consists of 10 basic ConvBnReLU blocks and 5 plain Conv2d layers; ConvBnReLU is a convolution followed by batch normalization and a ReLU activation:

    class ConvBnReLU(nn.Module):
        """Implements 2d Convolution + batch normalization + ReLU"""
    
        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            pad: int = 1,
            dilation: int = 1,
        ) -> None:
            """initialization method for convolution2D + batch normalization + relu module
            Args:
                in_channels: input channel number of convolution layer
                out_channels: output channel number of convolution layer
                kernel_size: kernel size of convolution layer
                stride: stride of convolution layer
                pad: pad of convolution layer
                dilation: dilation of convolution layer
            """
            super(ConvBnReLU, self).__init__()
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=pad, dilation=dilation, bias=False
            )
            self.bn = nn.BatchNorm2d(out_channels)
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """forward method"""
            return F.relu(self.bn(self.conv(x)), inplace=True)
    

    After the 10th ConvBnReLU block, the input x of shape torch.Size([1, 3, 512, 640]) has become conv10 of shape torch.Size([1, 64, 64, 80]): the feature map has been downsampled. One more Conv2d leaves the shape unchanged at torch.Size([1, 64, 64, 80]) and the result is stored as output_feature[3].

    Next, conv10 is upsampled by bilinear interpolation to twice its size, torch.Size([1, 64, 128, 160]), and conv7 (torch.Size([1, 32, 128, 160])) is passed through a Conv2d to shape torch.Size([1, 64, 128, 160]); adding the upsampled conv10 and the projected conv7 gives intra_feat of shape torch.Size([1, 64, 128, 160]). This fuses the features without changing the shape; conv7 and conv10 are deleted to free memory.

    The fused feature map intra_feat then goes through another Conv2d to shape torch.Size([1, 32, 128, 160]) and is stored as output_feature[2]; the projected conv4 is added to the upsampled intra_feat to form a new intra_feat of shape torch.Size([1, 64, 256, 320]), twice the previous resolution, and conv4 is deleted to free memory.

    Finally, intra_feat goes through one more Conv2d to shape torch.Size([1, 16, 256, 320]) and is stored as output_feature[1]: the spatial size doubles again while the channel count drops. The dictionary output_feature holding the extracted features is returned.

    After feature extraction of the 5 images, the 3 entries of each output_feature (one per stage) are stored in the list features, and the original images are deleted to free memory.
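    The shape bookkeeping above can be verified quickly with a dummy input (a sanity-check sketch; it assumes FeatureNet from the repository is importable and that random weights are fine for a pure shape check):

    import torch

    net = FeatureNet()
    x = torch.rand(1, 3, 512, 640)              # one dummy view, [B, 3, H, W]
    feats = net(x)
    for stage in (3, 2, 1):
        print(stage, tuple(feats[stage].shape))
    # 3 (1, 64, 64, 80)
    # 2 (1, 32, 128, 160)
    # 1 (1, 16, 256, 320)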

    (4) Learning-based patchmatch
    (figure: overview of the learning-based Patchmatch pipeline)
    Once feature extraction is done, PatchMatch learning follows; its detailed structure is shown below:
    (figure: detailed structure of the Patchmatch module)

    		# step 2. Learning-based patchmatch
            device = intrinsics.device
            depth = torch.empty(0, device=device)
            depths: List[torch.Tensor] = []
            score = torch.empty(0, device=device)
            view_weights = torch.empty(0, device=device)
            depth_patchmatch: Dict[int, List[torch.Tensor]] = {}
    
            scale = 0.125
            for stage in range(self.stages - 1, 0, -1):
                src_features_l = [src_fea[stage] for src_fea in src_features]
    
                # Create projection matrix for specific stage
                intrinsics_l = intrinsics.clone()
                intrinsics_l[:, :, :2] *= scale
                proj = extrinsics.clone()
                proj[:, :, :3, :4] = torch.matmul(intrinsics_l, extrinsics[:, :, :3, :4])
                proj_l = torch.unbind(proj, 1)
                ref_proj, src_proj = proj_l[0], proj_l[1:]
                scale *= 2.0
    
                # Need conditional since TorchScript only allows "getattr" access with string literals
                if stage == 3:
                    depths, score, view_weights = self.patchmatch_3(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
                elif stage == 2:
                    depths, score, view_weights = self.patchmatch_2(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
                elif stage == 1:
                    depths, score, view_weights = self.patchmatch_1(
                        ref_feature=ref_feature[stage],
                        src_features=src_features_l,
                        ref_proj=ref_proj,
                        src_projs=src_proj,
                        depth_min=depth_min,
                        depth_max=depth_max,
                        depth=depth,
                        view_weights=view_weights,
                    )
    
                depth_patchmatch[stage] = depths
                depth = depths[-1].detach()
    
                if stage > 1:
                    # upsampling the depth map and pixel-wise view weight for next stage
                    depth = F.interpolate(depth, scale_factor=2.0, mode="nearest")
                    view_weights = F.interpolate(view_weights, scale_factor=2.0, mode="nearest")
    
            del ref_feature
            del src_features
    

    PatchMatch learning starts at stage 3, where the feature maps are smallest, with shape torch.Size([1, 64, 64, 80]). The stage-3 features of the 4 src views are collected into src_features_l. The intrinsics are scaled to 1/8 of the original; the 3x3 scaled intrinsics are multiplied by the 3x4 part of the extrinsics to fill the 3x4 part of the 5 projection matrices proj, which copies the extrinsics' shape torch.Size([1, 5, 4, 4]) (a small standalone sketch of this step follows).
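    In isolation, the per-stage projection-matrix construction from the forward loop looks like this (a sketch with dummy tensors; scale = 0.125 corresponds to stage 3):

    import torch

    B, N = 1, 5                               # batch size, number of views (1 ref + 4 src)
    intrinsics = torch.rand(B, N, 3, 3)       # dummy intrinsics
    extrinsics = torch.rand(B, N, 4, 4)       # dummy extrinsics
    scale = 0.125                             # stage 3 works at 1/8 resolution

    intrinsics_l = intrinsics.clone()
    intrinsics_l[:, :, :2] *= scale           # scale the first two intrinsic rows
    proj = extrinsics.clone()
    proj[:, :, :3, :4] = torch.matmul(intrinsics_l, extrinsics[:, :, :3, :4])
    ref_proj, src_projs = proj[:, 0], [proj[:, i] for i in range(1, N)]
    print(ref_proj.shape, len(src_projs))     # torch.Size([1, 4, 4]) 4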

    At stage 3, the ref feature map, the src feature maps, the ref projection matrix, the src projection matrices and the min/max depth are passed into the patchmatch module; here propagate_neighbors=16 and patchmatch_iteration=2. The ref feature map is fed into the adaptive propagation convolution to obtain the propagation offset propa_offset:

        # adaptive propagation: the last iteration on stage 1 does not propagate,
        # but the conv is still defined for TorchScript export compatibility
            self.propa_conv = nn.Conv2d(
                in_channels=self.propa_num_feature,
                out_channels=max(2 * self.propagate_neighbors, 1),
                kernel_size=3,
                stride=1,
                padding=self.dilation,
                dilation=self.dilation,
                bias=True,
            )
    

    The ref feature of shape torch.Size([1, 64, 64, 80]) is mapped by propa_conv to torch.Size([1, 32, 64, 80]) (32 = 2 x 16 propagation neighbors) and then reshaped to (batch, 2 * self.propagate_neighbors, height * width), giving propa_offset of shape torch.Size([1, 32, 5120]).
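    As a shape check (dummy tensors; 16 propagation neighbors as on stage 3, and the dilation value is an assumption for this sketch):

    import torch
    import torch.nn as nn

    propagate_neighbors, dilation = 16, 2
    ref_feature = torch.rand(1, 64, 64, 80)            # stage-3 reference feature
    propa_conv = nn.Conv2d(64, 2 * propagate_neighbors, kernel_size=3, stride=1,
                           padding=dilation, dilation=dilation, bias=True)
    batch, _, height, width = ref_feature.shape
    propa_offset = propa_conv(ref_feature).view(batch, 2 * propagate_neighbors, height * width)
    print(propa_offset.shape)                          # torch.Size([1, 32, 5120])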

    propa_offset is then passed to the get_grid method, which computes the offsets for adaptive propagation (or, with the other grid type, the offsets for spatial cost aggregation in adaptive evaluation):

    def get_grid(
            self, grid_type: int, batch: int, height: int, width: int, offset: torch.Tensor, device: torch.device
        ) -> torch.Tensor:
            """Compute the offset for adaptive propagation or spatial cost aggregation in adaptive evaluation
    
            Args:
                grid_type: type of grid - propagation (1) or evaluation (2)
                batch: batch size
                height: grid height
                width: grid width
                offset: grid offset
                device: device on which to place tensor
    
            Returns:
                generated grid: in the shape of [batch, propagate_neighbors*H, W, 2]
            """
    
            if grid_type == self.grid_type["propagation"]:
                if self.propagate_neighbors == 4:  # if 4 neighbors to be sampled in propagation
                    original_offset = [[-self.dilation, 0], [0, -self.dilation], [0, self.dilation], [self.dilation, 0]]
                elif self.propagate_neighbors == 8:  # if 8 neighbors to be sampled in propagation
                    original_offset = [
                        [-self.dilation, -self.dilation],
                        [-self.dilation, 0],
                        [-self.dilation, self.dilation],
                        [0, -self.dilation],
                        [0, self.dilation],
                        [self.dilation, -self.dilation],
                        [self.dilation, 0],
                        [self.dilation, self.dilation],
                    ]
                elif self.propagate_neighbors == 16:  # if 16 neighbors to be sampled in propagation
                    original_offset = [
                        [-self.dilation, -self.dilation],
                        [-self.dilation, 0],
                        [-self.dilation, self.dilation],
                        [0, -self.dilation],
                        [0, self.dilation],
                        [self.dilation, -self.dilation],
                        [self.dilation, 0],
                        [self.dilation, self.dilation],
                    ]
                    for i in range(len(original_offset)):
                        offset_x, offset_y = original_offset[i]
                        original_offset.append([2 * offset_x, 2 * offset_y])
                else:
                    raise NotImplementedError
            elif grid_type == self.grid_type["evaluation"]:
                dilation = self.dilation - 1  # dilation of evaluation is a little smaller than propagation
                if self.evaluate_neighbors == 9:  # if 9 neighbors to be sampled in evaluation
                    original_offset = [
                        [-dilation, -dilation],
                        [-dilation, 0],
                        [-dilation, dilation],
                        [0, -dilation],
                        [0, 0],
                        [0, dilation],
                        [dilation, -dilation],
                        [dilation, 0],
                        [dilation, dilation],
                    ]
                elif self.evaluate_neighbors == 17:  # if 17 neighbors to be sampled in evaluation
                    original_offset = [
                        [-dilation, -dilation],
                        [-dilation, 0],
                        [-dilation, dilation],
                        [0, -dilation],
                        [0, 0],
                        [0, dilation],
                        [dilation, -dilation],
                        [dilation, 0],
                        [dilation, dilation],
                    ]
                    for i in range(len(original_offset)):
                        offset_x, offset_y = original_offset[i]
                        if offset_x != 0 or offset_y != 0:
                            original_offset.append([2 * offset_x, 2 * offset_y])
                else:
                    raise NotImplementedError
            else:
                raise NotImplementedError
    
            with torch.no_grad():
                y_grid, x_grid = torch.meshgrid(
                    [
                        torch.arange(0, height, dtype=torch.float32, device=device),
                        torch.arange(0, width, dtype=torch.float32, device=device),
                    ]
                )
                y_grid, x_grid = y_grid.contiguous().view(height * width), x_grid.contiguous().view(height * width)
                xy = torch.stack((x_grid, y_grid))  # [2, H*W]
                xy = torch.unsqueeze(xy, 0).repeat(batch, 1, 1)  # [B, 2, H*W]
    
            xy_list = []
            for i in range(len(original_offset)):
                original_offset_y, original_offset_x = original_offset[i]
                offset_x = original_offset_x + offset[:, 2 * i, :].unsqueeze(1)
                offset_y = original_offset_y + offset[:, 2 * i + 1, :].unsqueeze(1)
                xy_list.append((xy + torch.cat((offset_x, offset_y), dim=1)).unsqueeze(2))
    
            xy = torch.cat(xy_list, dim=2)  # [B, 2, 9, H*W]
    
            del xy_list
            del x_grid
            del y_grid
    
            x_normalized = xy[:, 0, :, :] / ((width - 1) / 2) - 1
            y_normalized = xy[:, 1, :, :] / ((height - 1) / 2) - 1
            del xy
            grid = torch.stack((x_normalized, y_normalized), dim=3)  # [B, 9, H*W, 2]
            del x_normalized
            del y_normalized
            return grid.view(batch, len(original_offset) * height, width, 2)
    

    This yields propa_grid of shape torch.Size([1, 1024, 80, 2]) (16 neighbors x 64 rows). Next, the learned additional 2D offsets for adaptive spatial cost aggregation (adaptive evaluation) are computed: eval_offset has shape torch.Size([1, 18, 5120]) and eval_grid has shape torch.Size([1, 576, 80, 2]).

    The eval_grid and the ref feature map are passed into self.feature_weight_net, which produces feature_weight of shape torch.Size([1, 9, 64, 80]).

    Then self.patchmatch_iteration iterations are run. Each iteration needs depth initialization, which takes the min/max depth, the depth-map width and height, the depth interval scale and the current depth, and returns the sampled depths:

    class DepthInitialization(nn.Module):
        """Initialization Stage Class"""
    
        def __init__(self, patchmatch_num_sample: int = 1) -> None:
            """Initialize method
    
            Args:
                patchmatch_num_sample: number of samples used in patchmatch process
            """
            super(DepthInitialization, self).__init__()
            self.patchmatch_num_sample = patchmatch_num_sample
    
        def forward(
            self,
            min_depth: torch.Tensor,
            max_depth: torch.Tensor,
            height: int,
            width: int,
            depth_interval_scale: float,
            device: torch.device,
            depth: torch.Tensor
        ) -> torch.Tensor:
            """Forward function for depth initialization
    
            Args:
                min_depth: minimum virtual depth, (B, )
                max_depth: maximum virtual depth, (B, )
                height: height of depth map
                width: width of depth map
                depth_interval_scale: depth interval scale
                device: device on which to place tensor
                depth: current depth (B, 1, H, W)
    
            Returns:
                depth_sample: initialized sample depth map by randomization or local perturbation (B, Ndepth, H, W)
            """
            batch_size = min_depth.size()[0]
            inverse_min_depth = 1.0 / min_depth
            inverse_max_depth = 1.0 / max_depth
            if is_empty(depth):
                # first iteration of Patchmatch on stage 3, sample in the inverse depth range
                # divide the range into several intervals and sample in each of them
                patchmatch_num_sample = 48
                # [B,Ndepth,H,W]
                depth_sample = torch.rand(
                    size=(batch_size, patchmatch_num_sample, height, width), device=device
                ) + torch.arange(start=0, end=patchmatch_num_sample, step=1, device=device).view(
                    1, patchmatch_num_sample, 1, 1
                )
    
                depth_sample = inverse_max_depth.view(batch_size, 1, 1, 1) + depth_sample / patchmatch_num_sample * (
                    inverse_min_depth.view(batch_size, 1, 1, 1) - inverse_max_depth.view(batch_size, 1, 1, 1)
                )
    
                return 1.0 / depth_sample
    
            elif self.patchmatch_num_sample == 1:
                return depth.detach()
            else:
                # other Patchmatch, local perturbation is performed based on previous result
                # uniform samples in an inversed depth range
                depth_sample = (
                    torch.arange(-self.patchmatch_num_sample // 2, self.patchmatch_num_sample // 2, 1, device=device)
                    .view(1, self.patchmatch_num_sample, 1, 1).repeat(batch_size, 1, height, width).float()
                )
                inverse_depth_interval = (inverse_min_depth - inverse_max_depth) * depth_interval_scale
                inverse_depth_interval = inverse_depth_interval.view(batch_size, 1, 1, 1)
    
                depth_sample = 1.0 / depth.detach() + inverse_depth_interval * depth_sample
    
                depth_clamped = []
                del depth
                for k in range(batch_size):
                    depth_clamped.append(
                        torch.clamp(depth_sample[k], min=inverse_max_depth[k], max=inverse_min_depth[k]).unsqueeze(0)
                    )
    
                return 1.0 / torch.cat(depth_clamped, dim=0)
    
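    A quick look at the first-iteration behaviour of the module above (dummy values; on stage 3 the first call receives an empty depth tensor, so 48 samples are drawn uniformly at random over the inverse depth range; the 935.0 maximum depth is just an assumed value for this sketch):

    import torch

    init = DepthInitialization(patchmatch_num_sample=16)
    depth_min = torch.tensor([425.0])
    depth_max = torch.tensor([935.0])   # assumed maximum depth, for illustration only
    empty_depth = torch.empty(0)        # an empty tensor marks the first iteration
    depth_sample = init(depth_min, depth_max, 64, 80, 0.025, torch.device("cpu"), empty_depth)
    print(depth_sample.shape)           # torch.Size([1, 48, 64, 80])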

    After depth sampling comes adaptive propagation (except in the last iteration of stage 1, which skips it).
    It takes the sampled depths and the propagation grid propa_grid and returns the propagated, sorted depth samples:

    class Propagation(nn.Module):
        """ Propagation module implementation"""
    
        def __init__(self) -> None:
            """Initialize method"""
            super(Propagation, self).__init__()
    
        def forward(self, depth_sample: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
            # [B,D,H,W]
            """Forward method of adaptive propagation
    
            Args:
                depth_sample: sample depth map, in shape of [batch, num_depth, height, width],
                grid: 2D grid for bilinear gridding, in shape of [batch, neighbors*H, W, 2]
    
            Returns:
                propagate depth: sorted propagate depth map [batch, num_depth+num_neighbors, height, width]
            """
            batch, num_depth, height, width = depth_sample.size()
            num_neighbors = grid.size()[1] // height
            propagate_depth_sample = F.grid_sample(
                depth_sample[:, num_depth // 2, :, :].unsqueeze(1),
                grid,
                mode="bilinear",
                padding_mode="border",
                align_corners=False
            ).view(batch, num_neighbors, height, width)
            return torch.sort(torch.cat((depth_sample, propagate_depth_sample), dim=1), dim=1)[0]
    
    

    Next come the weights for adaptive spatial cost aggregation in adaptive evaluation. Input: the sampled depths, the min/max depth, eval_grid, the interval scale and the number of evaluation neighbors; the function returns the depth weight:

    def depth_weight(
        depth_sample: torch.Tensor,
        depth_min: torch.Tensor,
        depth_max: torch.Tensor,
        grid: torch.Tensor,
        patchmatch_interval_scale: float,
        neighbors: int,
    ) -> torch.Tensor:
        """Calculate depth weight
        1. Adaptive spatial cost aggregation
        2. Weight based on depth difference of sampling points and center pixel
    
        Args:
            depth_sample: sample depth map, (B,Ndepth,H,W)
            depth_min: minimum virtual depth, (B,)
            depth_max: maximum virtual depth, (B,)
            grid: position of sampling points in adaptive spatial cost aggregation, (B, neighbors*H, W, 2)
            patchmatch_interval_scale: patchmatch interval scale,
            neighbors: number of neighbors to be sampled in evaluation
    
        Returns:
            depth weight
        """
        batch, num_depth, height, width = depth_sample.size()
        inverse_depth_min = 1.0 / depth_min
        inverse_depth_max = 1.0 / depth_max
    
        # normalization
        x = 1.0 / depth_sample
        del depth_sample
        x = (x - inverse_depth_max.view(batch, 1, 1, 1)) / (inverse_depth_min - inverse_depth_max).view(batch, 1, 1, 1)
    
        x1 = F.grid_sample(
            x, grid, mode="bilinear", padding_mode="border", align_corners=False
        ).view(batch, num_depth, neighbors, height, width)
        del grid
    
        # [B,Ndepth,N_neighbors,H,W]
        x1 = torch.abs(x1 - x.unsqueeze(2)) / patchmatch_interval_scale
        del x
    
        # sigmoid output approximate to 1 when x=4
        return torch.sigmoid(4.0 - 2.0 * x1.clamp(min=0, max=4)).detach()
    

    The returned depth weight is multiplied by the earlier feature_weight to obtain the combined weight. Evaluation then follows, producing the regressed depth map and the pixel-wise view weights, which are reused in later iterations.

    Evaluation takes: the ref feature map, the src feature maps, the ref projection matrix, the src projection matrices, the sampled depths, the evaluation grid eval_grid, the combined weight, the view weights view_weights, and an inverse-depth flag; it returns the expected depth, the score (probability) map and the view weights view_weights:

    class Evaluation(nn.Module):
        """Evaluation module for adaptive evaluation step in Learning-based Patchmatch
        Used to compute the matching costs for all the hypotheses and choose best solutions.
        """
    
        def __init__(self, G: int = 8) -> None:
            """Initialize method`
    
            Args:
                G: the feature channels of input will be divided evenly into G groups
            """
            super(Evaluation, self).__init__()
    
            self.G = G
            self.pixel_wise_net = PixelwiseNet(self.G)
            self.softmax = nn.LogSoftmax(dim=1)
            self.similarity_net = SimilarityNet(self.G)
    
        def forward(
            self,
            ref_feature: torch.Tensor,
            src_features: List[torch.Tensor],
            ref_proj: torch.Tensor,
            src_projs: List[torch.Tensor],
            depth_sample: torch.Tensor,
            grid: torch.Tensor,
            weight: torch.Tensor,
            view_weights: torch.Tensor,
            is_inverse: bool
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Forward method for adaptive evaluation
    
            Args:
                ref_feature: feature from reference view, (B, C, H, W)
                src_features: features from (Nview-1) source views, (Nview-1) * (B, C, H, W), where Nview is the number of
                    input images (or views) of PatchmatchNet
                ref_proj: projection matrix of reference view, (B, 4, 4)
                src_projs: source matrices of source views, (Nview-1) * (B, 4, 4), where Nview is the number of input
                    images (or views) of PatchmatchNet
                depth_sample: sample depth map, (B,Ndepth,H,W)
                grid: grid, (B, evaluate_neighbors*H, W, 2)
                weight: weight, (B,Ndepth,1,H,W)
                view_weights: Tensor to store weights of source views, in shape of (B,Nview-1,H,W),
                    Nview-1 represents the number of source views
                is_inverse: Flag for inverse depth regression
    
            Returns:
                depth_sample: expectation of depth sample, (B,H,W)
                score: probability map, (B,Ndepth,H,W)
                view_weights: optional, Tensor to store weights of source views, in shape of (B,Nview-1,H,W),
                    Nview-1 represents the number of source views
            """
            batch, feature_channel, height, width = ref_feature.size()
            device = ref_feature.device
    
            num_depth = depth_sample.size()[1]
            assert (
                len(src_features) == len(src_projs)
            ), "Patchmatch Evaluation: Different number of images and projection matrices"
            if not is_empty(view_weights):
                assert (
                    len(src_features) == view_weights.size()[1]
                ), "Patchmatch Evaluation: Different number of images and view weights"
    
            # Change to a tensor with value 1e-5
            pixel_wise_weight_sum = 1e-5 * torch.ones((batch, 1, 1, height, width), dtype=torch.float32, device=device)
            ref_feature = ref_feature.view(batch, self.G, feature_channel // self.G, 1, height, width)
            similarity_sum = torch.zeros((batch, self.G, num_depth, height, width), dtype=torch.float32, device=device)
    
            i = 0
            view_weights_list = []
            for src_feature, src_proj in zip(src_features, src_projs):
                warped_feature = differentiable_warping(
                    src_feature, src_proj, ref_proj, depth_sample
                ).view(batch, self.G, feature_channel // self.G, num_depth, height, width)
                # group-wise correlation
                similarity = (warped_feature * ref_feature).mean(2)
                # pixel-wise view weight
                if is_empty(view_weights):
                    view_weight = self.pixel_wise_net(similarity)
                    view_weights_list.append(view_weight)
                else:
                    # reuse the pixel-wise view weight from first iteration of Patchmatch on stage 3
                    view_weight = view_weights[:, i].unsqueeze(1)  # [B,1,H,W]
                    i = i + 1
    
                similarity_sum += similarity * view_weight.unsqueeze(1)
                pixel_wise_weight_sum += view_weight.unsqueeze(1)
    
            # aggregated matching cost across all the source views
            similarity = similarity_sum.div_(pixel_wise_weight_sum)  # [B, G, Ndepth, H, W]
            # adaptive spatial cost aggregation
            score = self.similarity_net(similarity, grid, weight)  # [B, G, Ndepth, H, W]
            # apply softmax to get probability
            score = torch.exp(self.softmax(score))
    
            if is_empty(view_weights):
                view_weights = torch.cat(view_weights_list, dim=1)  # [B,4,H,W], 4 is the number of source views
    
            if is_inverse:
                # depth regression: inverse depth regression
                depth_index = torch.arange(0, num_depth, 1, device=device).view(1, num_depth, 1, 1)
                depth_index = torch.sum(depth_index * score, dim=1)
    
                inverse_min_depth = 1.0 / depth_sample[:, -1, :, :]
                inverse_max_depth = 1.0 / depth_sample[:, 0, :, :]
                depth_sample = inverse_max_depth + depth_index / (num_depth - 1) * (inverse_min_depth - inverse_max_depth)
                depth_sample = 1.0 / depth_sample
            else:
                # depth regression: expectation
                depth_sample = torch.sum(depth_sample * score, dim=1)
    
            return depth_sample, score, view_weights.detach()
    

    Inside Evaluation, differentiable homography-based warping is used to warp each src feature map into the reference view; the group-wise correlation between the ref feature map and the warped src feature map gives the similarity, which is passed into self.pixel_wise_net to obtain the view weight view_weight (a hedged sketch of the warping step is given after the PixelwiseNet code below):

    class PixelwiseNet(nn.Module):
        """Pixelwise Net: A simple pixel-wise view weight network, composed of 1x1x1 convolution layers
        and sigmoid nonlinearities, takes the initial set of similarities to output a number between 0 and 1 per
        pixel as estimated pixel-wise view weight.
    
        1. The Pixelwise Net is used in adaptive evaluation step
        2. The similarity is calculated by ref_feature and other source_features warped by differentiable_warping
        3. The learned pixel-wise view weight is estimated in the first iteration of Patchmatch and kept fixed in the
        matching cost computation.
        """
    
        def __init__(self, G: int) -> None:
            """Initialize method
    
            Args:
                G: the feature channels of input will be divided evenly into G groups
            """
            super(PixelwiseNet, self).__init__()
            self.conv0 = ConvBnReLU3D(in_channels=G, out_channels=16, kernel_size=1, stride=1, pad=0)
            self.conv1 = ConvBnReLU3D(in_channels=16, out_channels=8, kernel_size=1, stride=1, pad=0)
            self.conv2 = nn.Conv3d(in_channels=8, out_channels=1, kernel_size=1, stride=1, padding=0)
            self.output = nn.Sigmoid()
    
        def forward(self, x1: torch.Tensor) -> torch.Tensor:
            """Forward method for PixelwiseNet
    
            Args:
                x1: pixel-wise view weight, [B, G, Ndepth, H, W], where G is the number of groups
            """
            # [B,1,H,W]
            return torch.max(self.output(self.conv2(self.conv1(self.conv0(x1))).squeeze(1)), dim=1)[0].unsqueeze(1)
    
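    The differentiable_warping helper called above is not reproduced in this post. For completeness, here is a sketch of what such a differentiable homography warping typically looks like (an illustrative, MVSNet-style implementation written for this note, not the exact function from the repository; the epsilon and the grid_sample normalization convention are assumptions):

    import torch
    import torch.nn.functional as F

    def differentiable_warping_sketch(
        src_fea: torch.Tensor,       # [B, C, H, W] source-view feature map
        src_proj: torch.Tensor,      # [B, 4, 4] source projection matrix (K@[R|t] in the top 3x4)
        ref_proj: torch.Tensor,      # [B, 4, 4] reference projection matrix
        depth_samples: torch.Tensor  # [B, Ndepth, H, W] depth hypotheses in the reference view
    ) -> torch.Tensor:
        """Back-project reference pixels at every depth hypothesis, reproject them into the
        source view, and bilinearly sample the source features there."""
        batch, channels, height, width = src_fea.shape
        num_depth = depth_samples.shape[1]

        # relative transform from the reference camera to the source camera
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot, trans = proj[:, :3, :3], proj[:, :3, 3:4]

        with torch.no_grad():
            # homogeneous pixel grid of the reference view, [B, 3, H*W]
            y, x = torch.meshgrid(
                [torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
                 torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]
            )
            xyz = torch.stack((x.reshape(-1), y.reshape(-1), torch.ones(height * width, device=src_fea.device)))
            xyz = xyz.unsqueeze(0).repeat(batch, 1, 1)

        # rotate, scale by each depth hypothesis, translate, then project to pixel coordinates
        rot_xyz = torch.matmul(rot, xyz)                                              # [B, 3, H*W]
        proj_xyz = rot_xyz.unsqueeze(2) * depth_samples.view(batch, 1, num_depth, -1) \
            + trans.view(batch, 3, 1, 1)                                              # [B, 3, Ndepth, H*W]
        proj_xy = proj_xyz[:, :2] / (proj_xyz[:, 2:3] + 1e-6)                         # [B, 2, Ndepth, H*W]

        # normalize to [-1, 1] for grid_sample
        gx = proj_xy[:, 0] / ((width - 1) / 2) - 1
        gy = proj_xy[:, 1] / ((height - 1) / 2) - 1
        grid = torch.stack((gx, gy), dim=3).view(batch, num_depth * height, width, 2)

        warped = F.grid_sample(src_fea, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
        return warped.view(batch, channels, num_depth, height, width)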

    The similarity weighted by the view weight is accumulated into similarity_sum, and the view weight is accumulated into pixel_wise_weight_sum; dividing the former by the latter gives the similarity aggregated over all src feature maps. This similarity, together with grid and weight, is passed into self.similarity_net to compute a score for every depth hypothesis:

    class SimilarityNet(nn.Module):
        """Similarity Net, used in Evaluation module (adaptive evaluation step)
        1. Do 1x1x1 convolution on aggregated cost [B, G, Ndepth, H, W] among all the source views,
            where G is the number of groups
        2. Perform adaptive spatial cost aggregation to get final cost (scores)
        """
    
        def __init__(self, G: int) -> None:
            """Initialize method
    
            Args:
                G: the feature channels of input will be divided evenly into G groups
            """
            super(SimilarityNet, self).__init__()
    
            self.conv0 = ConvBnReLU3D(in_channels=G, out_channels=16, kernel_size=1, stride=1, pad=0)
            self.conv1 = ConvBnReLU3D(in_channels=16, out_channels=8, kernel_size=1, stride=1, pad=0)
            self.similarity = nn.Conv3d(in_channels=8, out_channels=1, kernel_size=1, stride=1, padding=0)
    
        def forward(self, x1: torch.Tensor, grid: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            """Forward method for SimilarityNet
    
            Args:
                x1: [B, G, Ndepth, H, W], where G is the number of groups, aggregated cost among all the source views with
                    pixel-wise view weight
                grid: position of sampling points in adaptive spatial cost aggregation, (B, evaluate_neighbors*H, W, 2)
                weight: weight of sampling points in adaptive spatial cost aggregation, combination of
                    feature weight and depth weight, [B,Ndepth,1,H,W]
    
            Returns:
                final cost: in the shape of [B,Ndepth,H,W]
            """
    
            batch, G, num_depth, height, width = x1.size()
            num_neighbors = grid.size()[1] // height
    
            # [B,Ndepth,num_neighbors,H,W]
            x1 = F.grid_sample(
                input=self.similarity(self.conv1(self.conv0(x1))).squeeze(1),
                grid=grid,
                mode="bilinear",
                padding_mode="border",
                align_corners=False
            ).view(batch, num_depth, num_neighbors, height, width)
    
            return torch.sum(x1 * weight, dim=2)
    

    After the patchmatch stages we have depth_sample, score and view_weights; refinement then follows, and the depth maps of the different stages are stored in depth_patchmatch:

     # step 3. Refinement
            depth = self.upsample_net(ref_image, depth, depth_min, depth_max)
            if ref_width != orig_width or ref_height != orig_height:
                depth = F.interpolate(depth, size=[orig_height, orig_width], mode='bilinear', align_corners=False)
            depth_patchmatch[0] = [depth]
    
            if self.training:
                return depth, torch.empty(0, device=device), depth_patchmatch
            else:
                num_depth = self.patchmatch_num_sample[0]
                score_sum4 = 4 * F.avg_pool3d(
                    F.pad(score.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
                ).squeeze(1)
                # [B, 1, H, W]
                depth_index = depth_regression(
                    score, depth_values=torch.arange(num_depth, device=score.device, dtype=torch.float)
                ).long().clamp(0, num_depth - 1)
                photometric_confidence = torch.gather(score_sum4, 1, depth_index)
                photometric_confidence = F.interpolate(
                    photometric_confidence, size=[orig_height, orig_width], mode="nearest").squeeze(1)
    
                return depth, photometric_confidence, depth_patchmatch
    
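    The depth_regression helper used above is also not shown in the post. It is typically a soft-argmax expectation over the depth dimension; a minimal sketch under that assumption (the keepdim choice follows the "[B, 1, H, W]" comment in the calling code):

    import torch

    def depth_regression(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor:
        """Soft-argmax depth regression: expectation of depth_values under probability volume p.

        p: [B, Ndepth, H, W]; depth_values: [Ndepth].
        Returns [B, 1, H, W]; drop keepdim=True if a [B, H, W] map is wanted instead.
        """
        depth_values = depth_values.view(1, -1, 1, 1)
        return torch.sum(p * depth_values, dim=1, keepdim=True)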

    The final depth_patchmatch, together with depth_gt and mask, is passed into the loss function:

    def patchmatchnet_loss(
        depth_patchmatch: Dict[int, List[torch.Tensor]],
        depth_gt: List[torch.Tensor],
        mask: List[torch.Tensor],
    ) -> torch.Tensor:
        """Patchmatch Net loss function
    
        Args:
            depth_patchmatch: depth map predicted by patchmatch net
            depth_gt: ground truth depth map
            mask: mask for filter valid points
    
        Returns:
            loss: result loss value
        """
        loss = 0
        for i in range(0, 4):
            gt_depth = depth_gt[i][mask[i].bool()]
            for depth in depth_patchmatch[i]:
                loss = loss + F.smooth_l1_loss(depth[mask[i].bool()], gt_depth, reduction="mean")
    
        return loss
    
    

    Note that mask is used to select the valid pixels of depth_gt, so it should be boolean; hence the .bool() calls, otherwise an error is raised. The loss obtained here is the sum of the losses over all 4 stages, and it is back-propagated through the network.
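    Putting the pieces together, one training step can be sketched as follows (illustrative only: the real train_sample in the repository also handles device placement, logging and metrics; the optimizer settings and the sample-dict keys used here are placeholders, not the repository's actual names):

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # assumed optimizer and learning rate

    def train_sample_sketch(sample):
        model.train()
        optimizer.zero_grad()

        depth_gt = create_stage_images(sample["depth_gt"])      # 4-stage ground-truth pyramid from a [B, 1, H, W] tensor
        mask = create_stage_images(sample["mask"])              # 4-stage validity masks

        depth, _, depth_patchmatch = model(
            sample["images"], sample["intrinsics"], sample["extrinsics"],
            sample["depth_min"], sample["depth_max"],
        )
        loss = patchmatchnet_loss(depth_patchmatch, depth_gt, mask)
        loss.backward()
        optimizer.step()
        return loss.item()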

    Original article: https://blog.csdn.net/m0_46256255/article/details/134017063