• Pytroch Nerf代码阅读笔记(LLFF 数据集pose 处理和Nerf 网络结构)


     images, poses, bds, render_poses, i_test = load_llff_data(args.datadir, args.factor,
                                                                      recenter=True, bd_factor=.75,
                                                                    spherify=args.spherify)
    
    • 1
    • 2
    • 3

    从load_llff_data 中取出的pose 是一个(20,3,5)的list。20代表一共有20张image,3×5是每一个image 的pose matrix。
    在这里插入图片描述
    pose 蓝色部分包含rotation matrix 和 translation vector,就是平移和旋转,是一般意义上的位姿矩阵T (camera-to-world affine)。 第4列红色的部分,分别代表图像的高height,宽度width,和相机的焦距Focal:在train函数里面有如下代码:

    hwf = poses[0,:3,-1]  // 取出前三行最后一列元素(红色部分)
    poses = poses[:,:3,:4]  // 取出pose里的平移和旋转部分
    ....中间代码略去.......
    H, W, focal = hwf   // 分别赋予 Hieight、Width、focal
    
    • 1
    • 2
    • 3
    • 4

    关于poses_bounds.npy 解释:这个文件存储这一个numpy 的数组:N×17,N 是图像的数量,17 个元素将会被转化为 3*5 的矩阵和两个深度值:视角 到 场景的最近和最远距离。

    blender 数据集 lego 的读取

    介绍代码中的一个参数:arg.white_bkgd:
    在Blender 的数据集图像有四个通道RGBA,其中A表示的是alpha通道,一般情况下就是两个取值【0,1】,当alpha=0 表示该处的pixel是透明的;当alpha=1 表示该处的pixel是不透明的。 而 white_bkgd 这个参数就是负责将透明像素的部分转化为白色的背景,转化的代码部分如下:

            if args.white_bkgd:
                images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
    
    • 1
    • 2

    代码的解读:
    images是Normalize到【0,1】之间的图像,当alpha=0(也就是 images[…,-1:] = 0 ),那么images的像素将设置为1(纯白色);当alpha=1的时候,那么images的像素的就是本来的RGB通道对应的颜色。

    Nerf 代码的阅读:

    Nerf 网络的搭建:

    MLP 网络的创建,

    Input: layer = 0,Position Encoding 后的长度为 63 的vector
    layer =9 时,将第8层的输出(channel=256)和 direction 进行Postion Encoding 之后(channel=27)进行concat

    Output: 第8层的 density 为 alpha 的输出 和第10层的 rgb 3channel 的输出

    netdepth = 8 , netwidth = 256 , input_ch = 63,是指position输入的维度(position encoding 之后的编码),skip = 4, 是因为在论文中 第5层出现了 skip connection.

    model = NeRF(D=args.netdepth, W=args.netwidth,
                     input_ch=input_ch, output_ch=output_ch, skips=skips,
                     input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    
    • 1
    • 2
    • 3

    Nerf 的 网络构建代码如下:

    class NeRF(nn.Module):
        def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
            """ 
            """
            super(NeRF, self).__init__()
            self.D = D
            self.W = W
            ## Position Encoding之后的 位置vector通道数(63)
            self.input_ch = input_ch  
            ## Position Encoding之后的 direction的vector通道数(27)
            self.input_ch_views = input_ch_views
            self.skips = skips   ## 在第4层有跳跃连接
            self.use_viewdirs = use_viewdirs
            
            ## 前8层的MLP实现:输入为63,输出为 256
            self.pts_linears = nn.ModuleList(
                [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
            
            ### 构建了第9层的输入为 第8层的输出 和 direction 进行concat,输出为128 维
            self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
       
            if use_viewdirs:
                self.feature_linear = nn.Linear(W, W) # 第9层 输出256维的向量
                self.alpha_linear = nn.Linear(W, 1) # 第9层输出 density alpha(1维)
                self.rgb_linear = nn.Linear(W//2, 3)
            else:
                self.output_linear = nn.Linear(W, output_ch)
    
        def forward(self, x):
            input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
            h = input_pts
            for i, l in enumerate(self.pts_linears):
                h = self.pts_linears[i](h)
                h = F.relu(h)
                if i in self.skips:
                    h = torch.cat([input_pts, h], -1)
    
            if self.use_viewdirs:
                alpha = self.alpha_linear(h)
                feature = self.feature_linear(h)
                h = torch.cat([feature, input_views], -1) #第9层concat direction 向量
            
                for i, l in enumerate(self.views_linears):
                    h = self.views_linears[i](h)
                    h = F.relu(h)
    
                rgb = self.rgb_linear(h)  ## 输出rgb 3维度向量
                outputs = torch.cat([rgb, alpha], -1)
            else:
                outputs = self.output_linear(h)
    
            return outputs   
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
  • 相关阅读:
    3.11-程序基本的控制语句 3.12-表达式 3.13-数据类型 3.14-常量/变量 3.15-标识符
    An工具介绍之3D工具
    GAN原理及代码实现
    电力巡检/电力抢修行业解决方案:AI+视频技术助力解决巡检监管难题
    U盘文件损坏且无法读取?别着急,教你恢复的绝招!
    前端工程化面试题
    记-Windows环境下Prometheus+alertmanager+windows_exporter+mtail监控部署提起网关日志
    三模块七电平级联H桥整流器电压平衡控制策略Simulink仿真
    Vue基础知识(条件渲染、列表渲染、收集表单数据、过滤器)(三)
    Java内存区域
  • 原文地址:https://blog.csdn.net/qq_41623632/article/details/126468034