• spynet(六):光流整体结构


    16. 网络结构

    3层金字塔图

    ![3层金字塔图](spynet_note_im/20221014114943.png)

    1. G是各level的网络结构
      这里主要说下网络的输入,包括
      参考帧
      根据光流warp后的辅助帧
      光流
      对于每层而言:
      在这里插入图片描述

    每个 level 的网络输出的是光流的残差,即 groundtruth flow 与上一个 level 的 flow 经上采样后的结果之间的差值。
    这一点在后面的训练代码中会有体现;interpolate 插值函数和 warp 函数后面也会讲解。

    class SpyNetUnit(nn.Module):
        """One pyramid-level network (G_k) that outputs a 2-channel flow map.

        The body is five 7x7 convolutions (stride 1, padding 3, so height and
        width are preserved) with a ReLU after every convolution except the
        last one.
        """

        def __init__(self, input_channels: int = 8):
            super(SpyNetUnit, self).__init__()

            # Channel widths of consecutive convolutions:
            # input -> 32 -> 64 -> 32 -> 16 -> 2.
            widths = (input_channels, 32, 64, 32, 16, 2)
            layers = []
            for in_ch, out_ch in zip(widths[:-1], widths[1:]):
                layers.append(
                    nn.Conv2d(in_ch, out_ch, kernel_size=7, padding=3, stride=1))
                layers.append(nn.ReLU(inplace=False))
            # The final convolution emits the flow directly, so drop the
            # activation that the loop appended after it.
            del layers[-1]

            self.module = nn.Sequential(*layers)

        def forward(self, 
                    frames: Tuple[torch.Tensor, torch.Tensor], 
                    optical_flow: torch.Tensor = None,
                    upsample_optical_flow: bool = True) -> torch.Tensor:
            """Run the unit on a frame pair and an optional incoming flow.

            Args:
                frames: pair ``(first_frame, second_frame)``.
                optical_flow: flow handed over from the previous level, or
                    ``None`` at the top of the pyramid, in which case an
                    all-zero flow with the frames' spatial size is used.
                upsample_optical_flow: when true (and a flow was given), the
                    flow is bilinearly upsampled by 2 and its values doubled.

            Returns:
                The output of the convolution stack applied to the
                channel-wise concatenation
                ``[first_frame, warped second_frame, optical_flow]``.
            """
            reference, support = frames

            if optical_flow is None:
                # Top pyramid level: start from a zero flow that already has
                # the size of the frames, so no upsampling is needed.
                batch, _, height, width = reference.size()
                optical_flow = torch.zeros(
                    batch, 2, height, width, device=support.device)
            elif upsample_optical_flow:
                # Lower levels: the incoming flow has half the resolution, so
                # both its size and its values are scaled by 2.
                optical_flow = F.interpolate(
                    optical_flow, scale_factor=2, align_corners=True,
                    mode='bilinear') * 2

            warped = spynet.nn.warp(support, optical_flow, support.device)

            # Network input: first frame, warped second frame, current flow.
            stacked = torch.cat([reference, warped, optical_flow], dim=1)
            return self.module(stacked)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    2. 逐层训练

    这套代码是从金字塔最高层到最底层逐个进行训练的:
    总的来说条理不是很清晰,而且逐层训练的目的是什么?

    def train(**kwargs):
        """Train the pyramid level by level and save the assembled network.

        Keyword Args:
            levels: number of pyramid levels to train; it is removed from
                ``kwargs`` before the rest is forwarded.
            checkpoint_dir: directory that receives ``final.pt``.
            Every remaining keyword is passed on to ``train_one_level``.
        """
        torch.manual_seed(0)

        num_levels = kwargs.pop('levels')
        trained_units = []
        for level in range(num_levels):
            # The list starts empty and holds one trained unit per finished
            # level; each new level receives the units trained before it.
            unit = train_one_level(level, trained_units, **kwargs)
            trained_units.append(unit)

        # Wrap all trained units into one network and store its weights.
        full_model = spynet.SpyNet(trained_units)
        checkpoint_path = Path(kwargs['checkpoint_dir']) / 'final.pt'
        torch.save(full_model.state_dict(), str(checkpoint_path))
    
    
    def train_one_level(k: int, 
                        previous: Sequence[spynet.SpyNetUnit],
                        **kwargs) -> spynet.SpyNetUnit:
        """Train the unit of pyramid level ``k`` and save its weights.

        Args:
            k: index of the level to train.
            previous: units already trained for the earlier levels.
            **kwargs: reads ``root``, ``batch_size``, ``dl_num_workers``,
                ``finetune_name``, ``epochs`` and ``checkpoint_dir``.

        Returns:
            The unit trained for level ``k``.
        """
        print(f'Training level {k}...')

        train_ds, valid_ds = load_data(kwargs['root'], k)
        batch_size = kwargs['batch_size']
        num_workers = kwargs['dl_num_workers']
        train_dl, valid_dl = build_dl(train_ds, valid_ds, batch_size, num_workers)

        # ``unit`` is the network being trained at this level; ``pyramid`` is
        # the network built from the earlier levels.
        unit, pyramid = build_spynets(k, kwargs['finetune_name'], previous)

        optimizer = torch.optim.AdamW(
            unit.parameters(), lr=1e-5, weight_decay=4e-5)
        criterion = spynet.nn.EPELoss()

        num_epochs = kwargs['epochs']
        for epoch in range(num_epochs):
            epoch_header = f'Epoch [{epoch}] [Level {k}]'
            train_one_epoch(train_dl,
                            optimizer,
                            criterion,
                            unit,
                            pyramid,
                            print_freq=999999,
                            header=epoch_header)

        checkpoint_path = Path(kwargs['checkpoint_dir']) / f'{k}.pt'
        torch.save(unit.state_dict(), str(checkpoint_path))

        return unit
    def train_one_epoch(dl: DataLoader,
                        optimizer: torch.optim.AdamW,
                        criterion_fn: torch.nn.Module,
                        Gk: torch.nn.Module, 
                        prev_pyramid: torch.nn.Module = None, 
                        print_freq: int = 100,
                        header: str = ''):
        """Run one optimization pass of ``Gk`` over every batch of ``dl``.

        Args:
            dl: iterable of ``((frame_a, frame_b), target_flow)`` batches that
                also supports ``len()``.
            optimizer: optimizer that updates the parameters of ``Gk``.
            criterion_fn: loss called as ``criterion_fn(target, predictions)``.
            Gk: the unit being trained; switched to training mode.
            prev_pyramid: network of the earlier levels, or ``None``. When
                given it is switched to eval mode and run without gradients;
                its upsampled output is fed to ``Gk`` and subtracted from the
                target, so ``Gk`` is trained on the residual.
            print_freq: print the running mean loss every ``print_freq``
                batches.
            header: prefix of every printed line.
        """
        Gk.train()
        running_loss = 0.

        if prev_pyramid is not None:
            prev_pyramid.eval()

        for i, (x, y) in enumerate(dl):
            # ``device`` is a module-level name defined outside this function.
            x = x[0].to(device), x[1].to(device)
            y = y.to(device)

            if prev_pyramid is not None:
                with torch.no_grad():
                    Vk_1 = prev_pyramid(x)
                    # NOTE(review): the flow is upsampled spatially here but
                    # its values are not multiplied by 2, unlike the
                    # upsampling inside the unit's forward pass. Confirm
                    # against what ``prev_pyramid`` returns.
                    Vk_1 = F.interpolate(
                        Vk_1, scale_factor=2, mode='bilinear', align_corners=True)
            else:
                Vk_1 = None

            # The flow is already at this level's size, so the unit must not
            # upsample it again.
            predictions = Gk(x, Vk_1, upsample_optical_flow=False)

            if Vk_1 is not None:
                # Train on the residual between the target and the incoming flow.
                y = y - Vk_1

            loss = criterion_fn(y, predictions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (i + 1) % print_freq == 0:
                # ``i`` is zero-based, so ``i + 1`` batches have been summed.
                loss_mean = running_loss / (i + 1)
                print(f'{header} [{i}/{len(dl)}] loss {loss_mean:.4f}')

        # An empty loader has no defined mean; report NaN instead of crashing.
        num_batches = len(dl)
        loss_mean = running_loss / num_batches if num_batches else float('nan')
        print(f'{header} loss {loss_mean:.4f}')
    
    # Returns the network to train now and the network of the earlier levels.
    def build_spynets(k: int, name: str, 
                      previous: Sequence[torch.nn.Module]) \
                          -> Tuple[spynet.SpyNetUnit, spynet.SpyNet]:
        """Create the unit to train at level ``k`` and the earlier-level pyramid.

        Args:
            k: index of the level being trained.
            name: pretrained model name; the string ``'none'`` selects a
                freshly constructed unit instead of a pretrained one.
            previous: units of the earlier levels.

        Returns:
            ``(unit, pyramid)``: the unit on ``device`` in training mode, and
            a ``spynet.SpyNet`` built from ``previous`` on ``device`` in eval
            mode, or ``None`` when ``k == 0``.
        """
        if name == 'none':
            unit = spynet.SpyNetUnit()
        else:
            # Fine-tuning: take level k of the pretrained network.
            loaded = spynet.SpyNet.from_pretrained(name, map_location=device)
            unit = loaded.units[k]

        unit.to(device)
        unit.train()

        # The first level has no earlier levels to build a pyramid from.
        if k == 0:
            return unit, None

        pyramid = spynet.SpyNet(previous)
        pyramid.to(device)
        pyramid.eval()
        return unit, pyramid
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    3. warp 和 EPELoss 的实现

    F.grid_sample 函数 根据 flow (相当于一个查找表对应像素的位置)查表和插值得到一个新的图像

    import torch
    import torch.nn.functional as F
    
    
    def warp(image: torch.Tensor, 
             optical_flow: torch.Tensor,
             device: torch.device = torch.device('cpu')) -> torch.Tensor:
        """Resample ``image`` at the positions displaced by ``optical_flow``.

        A regular sampling grid in normalized ``[-1, 1]`` coordinates is
        built, the flow (given in pixels) is rescaled to the same range and
        added to it, and ``F.grid_sample`` looks the image up at the result.

        Args:
            image: tensor of shape ``(b, c, h, w)``.
            optical_flow: tensor whose channel 0 is the horizontal and channel
                1 the vertical displacement in pixels, shape ``(b, 2, h, w)``.
            device: kept for backward compatibility only. The grid is always
                created on ``image.device``, the only device on which
                ``grid_sample`` can run, so the default no longer breaks
                non-CPU inputs.

        Returns:
            The warped image, sampled with border padding and
            ``align_corners=True``.
        """
        b, _, im_h, im_w = image.size()

        # Build the base grid next to the image, in the image's dtype.
        xs = torch.linspace(-1.0, 1.0, im_w,
                            device=image.device, dtype=image.dtype)
        ys = torch.linspace(-1.0, 1.0, im_h,
                            device=image.device, dtype=image.dtype)
        hor = xs.view(1, 1, 1, im_w).expand(b, -1, im_h, -1)
        vert = ys.view(1, 1, im_h, 1).expand(b, -1, -1, im_w)
        grid = torch.cat([hor, vert], 1)

        # The flow is expressed in pixels, so scale it to the [-1, 1] range
        # of the grid before adding. A 1-pixel dimension would give a zero
        # divisor and NaN coordinates, hence the clamp to 1.
        half_w = max(im_w - 1, 1) / 2.0
        half_h = max(im_h - 1, 1) / 2.0
        optical_flow = torch.cat([
            optical_flow[:, 0:1, :, :] / half_w,
            optical_flow[:, 1:2, :, :] / half_h], dim=1)

        # Channels last (which corresponds to optical flow vectors coordinates)
        grid = (grid + optical_flow).permute(0, 2, 3, 1)
        return F.grid_sample(image, grid=grid, padding_mode='border', 
                             align_corners=True)
    
    # Euclidean distance between predicted and target flow vectors.
    class EPELoss(torch.nn.Module):  # end-point-error (EPE)
        """Average end-point error between two flow fields.

        The error of one flow vector is the Euclidean norm of the difference
        between prediction and target; the loss is the mean of that norm over
        all vectors. The flow components are taken along dimension 1, as in
        ``(b, 2, h, w)`` tensors.
        """

        def __init__(self):
            super(EPELoss, self).__init__()

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            """Return the mean per-vector Euclidean distance as a scalar tensor."""
            # Reduce over the flow-component dimension only. Summing over every
            # element would produce one global norm and leave nothing to average.
            dist = torch.linalg.vector_norm(target - pred, ord=2, dim=1)
            return dist.mean()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38

    17.光流的可视化方法

    参见博客: light flow 光流的常见可视化方法,光流图像生成

    18. torch.nn.functional.interpolate函数

    常用于 tensor 的上采样、下采样操作

    # Plain tensors work directly; the legacy ``Variable`` wrapper is not needed.
    x = torch.randn([1, 3, 64, 64])
    # Downsample by a scale factor, or by giving the target size explicitly.
    y0 = F.interpolate(x, scale_factor=0.5)
    y1 = F.interpolate(x, size=[32, 32])

    # Upsample with bilinear interpolation; ``align_corners`` is spelled out
    # so the chosen behaviour is explicit.
    y2 = F.interpolate(x, size=[128, 128], mode="bilinear", align_corners=False)

    print(y0.shape)
    print(y1.shape)
    print(y2.shape)
    
    输出:
    torch.Size([1, 3, 32, 32])
    torch.Size([1, 3, 32, 32])
    torch.Size([1, 3, 128, 128])
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    19. 光流估计的评价指标

    在这里插入图片描述

    20. 一个比较规整,易懂的spynet 网络模型

    """
    This code is based on Open-MMLab's one.
    https://github.com/open-mmlab/mmediting
    """
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from modules import flow_warp
    
    class SPyNet(nn.Module):
        """SPyNet optical-flow estimator with six coarse-to-fine levels.

        Compared with the SPyNet in [tof.py], this version
            1. uses more SPyNetBasicModule instances, and
            2. uses no batch normalization.

        Paper:
            Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017

        The constructor takes no arguments; loading of pre-trained weights is
        left commented out in ``__init__``.
        """

        def __init__(self):
            super().__init__()

            # One basic module per pyramid level, indexed from the coarsest.
            level_modules = []
            for _ in range(6):
                level_modules.append(SPyNetBasicModule())
            self.basic_module = nn.ModuleList(level_modules)

            #self.load_state_dict(torch.load('spynet_20210409-c6c1bd09.pth'))

            # Per-channel statistics used to normalize the input images.
            channel_stats = (
                ('mean', [0.485, 0.456, 0.406]),
                ('std', [0.229, 0.224, 0.225]),
            )
            for buffer_name, values in channel_stats:
                self.register_buffer(
                    buffer_name, torch.Tensor(values).view(1, 3, 1, 1))

        def compute_flow(self, ref, supp):
            """Compute flow from ref to supp.

            The images are expected to be already resized to a multiple of 32.

            Args:
                ref (Tensor): Reference image with shape of (n, 3, h, w).
                supp (Tensor): Supporting image with shape of (n, 3, h, w).

            Returns:
                Tensor: Estimated optical flow: (n, 2, h, w).
            """
            batch, _, height, width = ref.size()

            # Normalize the inputs; they form the finest pyramid level.
            ref_pyramid = [(ref - self.mean) / self.std]
            supp_pyramid = [(supp - self.mean) / self.std]

            # Add five successively halved levels to both pyramids.
            for _ in range(5):
                for pyramid in (ref_pyramid, supp_pyramid):
                    pyramid.append(
                        F.avg_pool2d(
                            input=pyramid[-1],
                            kernel_size=2,
                            stride=2,
                            count_include_pad=False))

            # Process from the coarsest level to the finest one.
            ref_pyramid.reverse()
            supp_pyramid.reverse()

            flow = ref_pyramid[0].new_zeros(batch, 2, height // 32, width // 32)
            for level, (ref_k, supp_k) in enumerate(zip(ref_pyramid, supp_pyramid)):
                if level > 0:
                    # Bring the previous estimate to this level's resolution;
                    # the flow values double together with the image size.
                    flow = F.interpolate(
                        input=flow,
                        scale_factor=2,
                        mode='bilinear',
                        align_corners=True) * 2.0

                warped = flow_warp(
                    supp_k, flow.permute(0, 2, 3, 1), padding_mode='border')
                level_input = torch.cat([ref_k, warped, flow], 1)

                # Add the predicted residue to the upsampled flow.
                flow = flow + self.basic_module[level](level_input)

            return flow

        def forward(self, ref, supp):
            """Forward function of SPyNet.

            Computes the optical flow from ref to supp.

            Args:
                ref (Tensor): Reference image with shape of (n, 3, h, w).
                supp (Tensor): Supporting image with shape of (n, 3, h, w).

            Returns:
                Tensor: Estimated optical flow: (n, 2, h, w).
            """
            height, width = ref.shape[2:4]

            # Round both sides up to the next multiple of 32.
            width_up = ((width + 31) // 32) * 32
            height_up = ((height + 31) // 32) * 32

            resize_args = dict(
                size=(height_up, width_up), mode='bilinear', align_corners=False)
            ref = F.interpolate(input=ref, **resize_args)
            supp = F.interpolate(input=supp, **resize_args)

            # Estimate at the padded size, then return to the original size.
            flow = F.interpolate(
                input=self.compute_flow(ref, supp),
                size=(height, width),
                mode='bilinear',
                align_corners=False)

            # Rescale the flow values by the same ratio as the image size.
            flow[:, 0, :, :] *= float(width) / float(width_up)
            flow[:, 1, :, :] *= float(height) / float(height_up)

            return flow
    
    
    class SPyNetBasicModule(nn.Module):
        """Basic Module for SPyNet.

        Five 7x7 convolutions (stride 1, padding 3) with a ReLU after each one
        except the last; channels go 8 -> 32 -> 64 -> 32 -> 16 -> 2.

        Paper:
            Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
        """

        def __init__(self):
            super().__init__()

            widths = (8, 32, 64, 32, 16, 2)
            layers = []
            for in_ch, out_ch in zip(widths[:-1], widths[1:]):
                layers.append(
                    nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                              kernel_size=7, stride=1, padding=3))
                layers.append(nn.ReLU())
            # No activation follows the final convolution.
            del layers[-1]

            self.basic_module = nn.Sequential(*layers)

        def forward(self, tensor_input):
            """Apply the convolution stack.

            Args:
                tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                    8 channels contain:
                    [reference image (3), neighbor image (3), initial flow (2)].

            Returns:
                Tensor: Refined flow with shape (b, 2, h, w)
            """
            refined = self.basic_module(tensor_input)
            return refined
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
    • 95
    • 96
    • 97
    • 98
    • 99
    • 100
    • 101
    • 102
    • 103
    • 104
    • 105
    • 106
    • 107
    • 108
    • 109
    • 110
    • 111
    • 112
    • 113
    • 114
    • 115
    • 116
    • 117
    • 118
    • 119
    • 120
    • 121
    • 122
    • 123
    • 124
    • 125
    • 126
    • 127
    • 128
    • 129
    • 130
    • 131
    • 132
    • 133
    • 134
    • 135
    • 136
    • 137
    • 138
    • 139
    • 140
    • 141
    • 142
    • 143
    • 144
    • 145
    • 146
    • 147
    • 148
    • 149
    • 150
    • 151
    • 152
    • 153
    • 154
    • 155
    • 156
    • 157
    • 158
    • 159
    • 160
    • 161
    • 162
    • 163
    • 164
    • 165
  • 相关阅读:
    十四届蓝桥青少组模拟赛Python-20221108
    多功能音乐播放器beaTunes5 mac中文特点
    设计模式学习(三):工厂模式
    如何在公网环境下使用移动端通过群晖管家管理部署自己家里局域网内的黑群晖
    微软警告 Windows 8.1 用户:系统即将停止支持,建议购买 Win11/10 新设备
    npm包管理
    学会 Arthas,让你 3 年经验掌握 5 年功力!
    【毕业设计】基于java+swing+GUI的雷电游戏GUI设计与实现(毕业论文+程序源码)——雷电游戏
    【C语言】八大排序算法
    谈谈前端的本地存储indexedDB
  • 原文地址:https://blog.csdn.net/tywwwww/article/details/127403373