• [论文阅读笔记18] DiffusionDet论文笔记与代码解读



    扩散模型近期在图像生成领域很火, 没想到很快就被用在了检测上. 打算对这篇论文做一个笔记.

    论文地址: 论文
    代码: 代码


    0. 扩散模型简述

    首先介绍什么是扩散模型. 我们考虑生成任务, 即encoder-decoder形式的模型, encoder提取输入的抽象信息, 并尝试在decoder中恢复出来. 扩散模型就是这一类中的方法, 其灵感由热力学而来, 基本做法是在输入中逐步加噪, 并学会如何在噪声中恢复出输入. 假定加噪的过程为Markov过程.

    扩散模型和GAN, VAE虽然同为生成式模型, 但其思想不同. GAN是将模型分为生成器与鉴别器两个部分, 生成器的目的是让鉴别器分不出她的输出并非来自于真实数据集合, 而鉴别器的目的是不要被生成器欺骗. 这种博弈的方式有的时候也会陷入一些困境(例如难以到达纳什均衡). VAE得到的潜在变量(latent variable)的维度是小于输入的, 而扩散模型的中间变量的维度与输入相同.

    0.1 加噪的前向过程

    假定原始数据服从分布 x 0 ∼ q ( x ) \mathbf{x}_0\sim q(\mathbf{x}) x0q(x), 现在我们逐步对其加噪, 加入的是高斯噪声. 对于每一步加噪, 我们希望将分布 q q q逐渐向高斯过程靠近, 也即让 q ( x t ∣ x t − 1 ) = N q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N} q(xtxt1)=N. 在每一步, 我们假定高斯分布的均值与过去的值 x t − 1 \mathbf{x}_{t-1} xt1有关, 而协方差为固定值(对角阵):

    q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(\mathbf{x}_t|\mathbf{x}_{t-1})=\mathcal{N}(\mathbf{x}_t;\sqrt{1-\beta_t}\mathbf{x}_{t-1}, \beta_t \mathbf{I}) q(xtxt1)=N(xt;1βt xt1,βtI)

    其中 β t \beta_t βt为一小正数, 在0~1之间.

    注意我们假定加噪过程为Markov过程, 因此当前状态 x t \mathbf{x}_t xt只假定与上一状态 x t − 1 \mathbf{x}_{t-1} xt1有关.

    因此, 当前时刻 x t \mathbf{x}_t xt是由前一时刻 x t − 1 \mathbf{x}_{t-1} xt1决定的正态分布, 其均值为 1 − β t x t − 1 \sqrt{1-\beta_t}\mathbf{x}_{t-1} 1βt xt1, 方差为 β t I \beta_t \mathbf{I} βtI. 为了表示 x t \mathbf{x}_t xt, 这里我们使用一下重参数化(Re-parametrization)技巧. 重参数化是说, 如果我们从一个高斯分布中取样, 也等效于从标准高斯分布中取样, 只不过是加上均值, 以及乘以标准差. 这是因为一个 x ∼ N ( μ , σ 2 ) x\sim\mathcal{N}(\mu, \sigma^2) xN(μ,σ2)的高斯分布可以等价于 μ + σ ϵ \mu+\sigma \epsilon μ+σϵ, 其中 ϵ ∼ N ( 0 , 1 ) \epsilon\sim\mathcal{N}(0,1) ϵN(0,1).

    由高斯分布性质立即得到.

    因此, x t \mathbf{x}_t xt可以表示为:

    x t = 1 − β t x t − 1 + β t ϵ t − 1 \mathbf{x}_t=\sqrt{1-\beta_t} \mathbf{x}_{t-1}+\sqrt{\beta_t}\epsilon_{t-1} xt=1βt xt1+βt ϵt1

    其中 ϵ t − 1 ∼ N ( 0 , I ) \epsilon_{t-1}\sim\mathcal{N}(0,I) ϵt1N(0,I). 为了表示方便, 令 1 − β t = α t \sqrt{1-\beta_t}=\sqrt{\alpha_t} 1βt =αt , 将上式递归展开, 我们有:

    x t = α t ( α t − 1 x t − 2 + 1 − α t − 1 ϵ t − 2 ) + β t ϵ t − 1 \mathbf{x}_t=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\mathbf{x}_{t-2}+\sqrt{1-\alpha_{t-1}}\epsilon_{t-2} )+\sqrt{\beta_t}\epsilon_{t-1} xt=αt (αt1 xt2+1αt1 ϵt2)+βt ϵt1

    我们注意到后面两项可以合并为一个新的高斯分布, 其均值为0, 方差为 1 − α t α t − 1 1-\alpha_t\alpha_{t-1} 1αtαt1, 按此规律展开, 我们得到:

    x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)

    所以, 我们可以直接从 x 0 \mathbf{x}_0 x0得到 x t \mathbf{x}_t xt的分布:

    q ( x t ∣ x 0 ) = N ( x t ; Π i α i x 0 , 1 − Π i α i I ) q(\mathbf{x}_t|\mathbf{x}_0)=\mathcal{N}(\mathbf{x}_t;\sqrt{\Pi_i\alpha_i}\mathbf{x}_0, \sqrt{1-\Pi_i\alpha_i} \mathbf{I}) q(xtx0)=N(xt;Πiαi x0,1Πiαi I)

    所以, 随着时间的增加, x t \mathbf{x}_t xt会越来越趋向于标准正态分布. 以上就是加噪的过程.

    0.2 去噪的反向过程

    我们假定, 在加噪的正向过程中最后的结果已经近似为标准高斯分布 x T ∼ N ( 0 , I ) \mathbf{x}_T \sim \mathcal{N}(0,\mathbf{I}) xTN(0,I). 我们现在希望从加噪后的高斯分布中恢复出原来的信号, 即, 通过逐步计算 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt)恢复. 然而, 如果这样计算的话, 需要从整个数据集中采样, 计算量非常大(有可能是因为类似于高斯混合模型的过程), 为此, 我们希望学习出一个模型 p θ p_\theta pθ来学习恢复过程中的条件概率:
    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 , μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)=\mathcal{N}(\mathbf{x}_{t-1},\mu_\theta(\mathbf{x}_t,t), \Sigma_\theta(\mathbf{x}_t,t)) pθ(xt1xt)=N(xt1,μθ(xt,t),Σθ(xt,t))

    我们需要做的是让分布 p ( x t − 1 ∣ x t ) p(\mathbf{x}_{t-1}|\mathbf{x}_t) p(xt1xt)尽可能与 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt)接近.

    我们很难计算 q ( x t − 1 ∣ x t ) q(\mathbf{x}_{t-1}|\mathbf{x}_t) q(xt1xt), 但可以考察以 x 0 为条件的以下概率 \mathbf{x}_0为条件的以下概率 x0为条件的以下概率:

    q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0)

    根据Bays公式, 有:

    q ( x t − 1 ∣ x t , x 0 ) = q ( x t , x t − 1 , ∣ x 0 ) q ( x t ∣ x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)=\frac{q(\mathbf{x}_t,\mathbf{x}_{t-1},|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)}=q(\mathbf{x}_t|\mathbf{x}_{t-1}, \mathbf{x}_0)\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtx0)q(xt,xt1,x0)=q(xtxt1,x0)q(xtx0)q(xt1x0)

    扩散过程为Markov过程, 因此有:

    q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0)=q(\mathbf{x}_t|\mathbf{x}_{t-1})\frac{q(\mathbf{x}_{t-1}|\mathbf{x}_0)}{q(\mathbf{x}_t|\mathbf{x}_0)} q(xt1xt,x0)=q(xtxt1)q(xtx0)q(xt1x0)

    代入高斯分布表达式, 并凑出均值和方差(整理成 exp ⁡ { 1 2 σ 2 ( x t − μ ) 2 } \exp\{\frac{1}{2\sigma^2}(x_t-\mu)^2\} exp{2σ21(xtμ)2}的形式), 我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0)均值为:

    μ = α t ( 1 − Π i T − 1 α i ) 1 − Π i T α i x t + Π i T − 1 α i ( 1 − α t ) 1 − Π i T α i x 0 \mu=\frac{\sqrt{\alpha_t}(1-\Pi_i^{T-1}\alpha_i)}{1-\Pi_i^{T}\alpha_i}\mathbf{x}_{t}+\frac{\sqrt{\Pi_i^{T-1}\alpha_i}(1-\alpha_t)}{1-\Pi_i^{T}\alpha_i}\mathbf{x}_{0} μ=1ΠiTαiαt (1ΠiT1αi)xt+1ΠiTαiΠiT1αi (1αt)x0
    根据前面的重参数化技巧, 有 x t = Π i α i x 0 + 1 − Π i α i ϵ t ,    ϵ t \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon_t, ~~\epsilon_t xt=Πiαi x0+1Πiαi ϵt,  ϵt为网络在这一步预测的高斯噪声, 代入上式得到:

    μ = 1 α t ( x t − 1 − α t 1 − Π i T α i ϵ ) \mu=\frac{1}{\sqrt{\alpha_t}}(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\Pi_i^{T}\alpha_i}}\epsilon) μ=αt 1(xt1ΠiTαi 1αtϵ)

    方差为:

    Σ = 1 − Π i T − 1 α i 1 − Π i T α i \Sigma=\frac{1-\Pi_i^{T-1}\alpha_i}{1-\Pi_i^{T}\alpha_i} Σ=1ΠiTαi1ΠiT1αi

    所以

    q ( x t − 1 ∣ x t , x 0 ) ∼ N ( μ , Σ ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) \sim \mathcal{N}(\mu, \Sigma) q(xt1xt,x0)N(μ,Σ)

    所以, 逆扩散的过程为: 根据网络从 x t \mathbf{x}_t xt预测的噪声 ϵ t \epsilon_t ϵt计算出均值与方差, 进而得到 q ( x t − 1 ∣ x t , x 0 ) q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) q(xt1xt,x0), 作为 p θ ( x t − 1 ∣ x t ) p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t) pθ(xt1xt)的近似, 如此得到 x t − 1 \mathbf{x}_{t-1} xt1, 再根据 x t − 1 \mathbf{x}_{t-1} xt1预测出下一步的噪声 ϵ t − 1 \epsilon_{t-1} ϵt1, 如此往复, 如下图所示(图源知乎)

    在这里插入图片描述

    0.3 采样过程的加速

    然而, 如果按照上述方式更新, 则采样速度非常慢. 一种方式是我们可以跨步采样, 也就是一共 T T T个恢复时长, 我们每隔 ⌈ T / S ⌉ \lceil T/S \rceil T/S步采样一次, 这样只需要采样 S S S次.

    另一种方法是, 我们直接通过前向加噪过程的变形来计算当前的恢复过程. 前向加噪过程与原始输入 x 0 \mathbf{x}_0 x0的关系为:

    x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)

    为了表示方便, 下面以 α ˉ k \bar{\alpha}_k αˉk表示 Π i = 1 k α i \Pi_{i=1}^k\alpha_i Πi=1kαi.

    在噪声恢复过程中, 我们以网络预测的噪声 ϵ t \epsilon_t ϵt估计加噪过程中加入的噪声, 即

    x t = α ˉ t x 0 + 1 − α ˉ t ϵ t \mathbf{x}_t=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\epsilon_t xt=αˉt x0+1αˉt ϵt

    翻转的过程用 x t \mathbf{x}_t xt估计 x t − 1 \mathbf{x}_{t-1} xt1, 将上式的 t t t换成 t − 1 t-1 t1, 有:

    x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ t − 1 \mathbf{x}_{t-1}=\sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_{t-1}}\epsilon_{t-1} xt1=αˉt1 x0+1αˉt1 ϵt1

    但我们在 x t \mathbf{x}_t xt时刻只能得到该时刻的噪声预测 ϵ t \epsilon_t ϵt, 因此对上式做恒等变换:

    x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 ϵ t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ

    xt1=α¯t1x0+1α¯t1ϵt1=α¯t1x0+1α¯t1σt2ϵt+σtϵ" role="presentation">xt1=α¯t1x0+1α¯t1ϵt1=α¯t1x0+1α¯t1σt2ϵt+σtϵ
    xt1=αˉt1 x0+1αˉt1 ϵt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ

    该式也可以理解为给采样增加不确定度 σ t ϵ \sigma_t\boldsymbol{\epsilon} σtϵ, 实际上DiffusionDet采样正是采用的这个公式.

    所以

    q σ ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 x t − α ˉ t x 0 1 − α ˉ t , σ t 2 I ) q_\sigma(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \frac{\mathbf{x}_t - \sqrt{\bar{\alpha}_t}\mathbf{x}_0}{\sqrt{1 - \bar{\alpha}_t}}, \sigma_t^2 \mathbf{I}) qσ(xt1xt,x0)=N(xt1;αˉt1 x0+1αˉt1σt2 1αˉt xtαˉt x0,σt2I)

    对比形式 q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ ( x t , x 0 ) , β ~ t I ) q(\mathbf{x}_{t-1} \vert \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I}) q(xt1xt,x0)=N(xt1;μ~(xt,x0),β~tI), 得到

    β ~ t = σ t 2 = 1 − α ˉ t − 1 1 − α ˉ t ⋅ β t = 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \tilde{\beta}_t = \sigma_t^2 = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t=\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t) β~t=σt2=1αˉt1αˉt1βt=1αˉt1αˉt1(1αt)

    在实践中, 令 σ t 2 = η ⋅ β ~ t \sigma_t^2 = \eta \cdot \tilde{\beta}_t σt2=ηβ~t来控制采样的随机程度. 当 η = 0 \eta = 0 η=0的时候, 表明采样过程是完全确定的(由网络预测的 ϵ t \boldsymbol{\epsilon}_t ϵt决定, 消除了另一个随机因子 ϵ \boldsymbol{\epsilon} ϵ的影响).

    总结一下, 可以在翻转过程中进行如下步骤来提高速度:

    1. 在第 t t t步, 获取 α ˉ t , α ˉ t − 1 , α t \bar{\alpha}_t, \bar{\alpha}_{t-1}, \alpha_t αˉt,αˉt1,αt
    2. 获取网络预测的噪声 ϵ t \boldsymbol{\epsilon}_t ϵt
    3. 计算 σ t = η 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \sigma_t=\eta \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t)} σt=η1αˉt1αˉt1(1αt)
    4. 计算 x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ xt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ
    5. 直至 t = 0 t=0 t=0

    0.4 损失函数

    我们得到 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)的目的是: 让网络学习的 p p p q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)尽量接近, 或说用 q ( x t − 1 ∣ x t , x 0 ) q(\textbf{x}_{t-1}|\textbf{x}_t, \textbf{x}_0) q(xt1xt,x0)指导 p p p的训练(可认为最小化二者之间的散度).

    具体损失函数再补充.

    1. DiffusionDet

    1.1 总览

    DiffusionDet的思想非常直接, 既然目标检测是要准确地定位边界框的位置, 那么利用Diffusion Model的强大噪声恢复(学习)能力就可以优化检测的结果. 整体框架如下:

    在这里插入图片描述
    上图中, Image Encoder(为ResNet或Swin Transformer)提取图像的特征, 然后Detection Decoder接受噪声化的边界框, 并恢复边界框的初始值, 同时预测类别. 整体来说, 需要学习一个网络 f θ f_\theta fθ, 从 z T z_T zT中恢复出 z 0 z_0 z0, 其中 z z z为边界框. 损失函数即为恢复的值与初始值的差的2-范数:

    L t r a i n = 1 2 ∣ ∣ f θ ( z t , t ) − z 0 ∣ ∣ 2 \mathcal{L}_{train}=\frac{1}{2}||f_\theta (z_t,t)-z_0||^2 Ltrain=21∣∣fθ(zt,t)z02

    如上图所示, 为了减少计算量, Diffusion Model从原始图片提取的高级特征中学习. Image Encoder就是提取图像特征的, 作者采用了ResNet和SwinTransformer.

    而Detection Decoder接受加噪的bbox和特征图, 并返回恢复的bbox.

    1.2 训练过程

    训练过程的每次迭代大致分为四步:

    1. Encoder提取特征
    2. 从标准高斯分布中采样给真值框加噪, 公式: x t = Π i α i x 0 + 1 − Π i α i ϵ ,    ϵ ∼ N ( 0 , I ) \mathbf{x}_t=\sqrt{\Pi_i\alpha_i}\mathbf{x}_0+\sqrt{1-\Pi_i\alpha_i}\epsilon, ~~\epsilon\sim \mathcal{N}(0,I) xt=Πiαi x0+1Πiαi ϵ,  ϵN(0,I)
    3. 将加噪后的真值框和特征图输入到要学习的encoder网络 f θ f_\theta fθ
    4. 计算loss

    伪代码:

    在这里插入图片描述

    训练过程中有几个细节:

    1. 保证每次迭代输入到encoder中的框数目都相同. 作者通过尝试重复GT框, 以及concat随机大小的框或与图像大小相同的框, 发现还是concat随机大小的框效果最好
    2. 训练损失. 对于预测的框采用集合预测损失. 值得注意的是, 为每个gt框分配 k k k个预测框, 而 k k k个预测框的选取是利用指派问题进行分配(类似匈牙利算法).

    1.3 推理过程

    推理过程大致分为三步:

    1. Encoder提取特征
    2. 从标准高斯分布中产生边界框
    3. T T T 0 0 0, 将随机框, 特征和时间输入到decoder中, 逐步恢复出初始边界框. 恢复的过程具体是:
    1. 在第 t t t步, 获取 α ˉ t , α ˉ t − 1 , α t \bar{\alpha}_t, \bar{\alpha}_{t-1}, \alpha_t αˉt,αˉt1,αt
    2. 获取网络预测的噪声 ϵ t \boldsymbol{\epsilon}_t ϵt
    3. 计算 σ t = η 1 − α ˉ t − 1 1 − α ˉ t ⋅ ( 1 − α t ) \sigma_t=\eta \sqrt{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot (1-\alpha_t)} σt=η1αˉt1αˉt1(1αt)
    4. 计算 x t − 1 = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 − σ t 2 ϵ t + σ t ϵ \mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2} \boldsymbol{\epsilon}_t + \sigma_t\boldsymbol{\epsilon} \\ xt1=αˉt1 x0+1αˉt1σt2 ϵt+σtϵ
    5. 直至 t = 0 t=0 t=0

    伪代码:

    在这里插入图片描述

    在推理过程中值得注意的是bbox更新机制. 由于输入的是固定数量的随机框, 在训练阶段我们也是加入了随机框来使数目一样, 因此输出的有些是对应于GT的bbox, 有些则是随机的. 如果把随机的再一起喂到下一步, 作者说这样就破坏了原本的分布, 因此对于每一步预测的框, 将置信度过低的舍弃, 并以新的随机框补充.

    2. 代码解读

    首先看一下./diffusiondet/detector.py中的DiffusionDet类, 其是该论文的核心代码. 其中的forward函数:

    def forward(self, batched_inputs, do_postprocess=True):
            images, images_whwh = self.preprocess_image(batched_inputs)  # 预处理 归一化&填充
            if isinstance(images, (list, torch.Tensor)):
                images = nested_tensor_from_tensor_list(images)
    
            # Feature Extraction.
            src = self.backbone(images.tensor)  # Encoder 提取各级特征
            features = list()
            for f in self.in_features:
                feature = src[f]
                features.append(feature)
    
            # Prepare Proposals.
            if not self.training:  # 如果是推理阶段
                results = self.ddim_sample(batched_inputs, features, images_whwh, images)  # 从T时刻至0时刻 逐步采样恢复
                return results
    
            if self.training:  # 训练阶段
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                targets, x_boxes, noises, t = self.prepare_targets(gt_instances)  # prepare_targets: 对GT框逐步加噪
                t = t.squeeze(-1)
                x_boxes = x_boxes * images_whwh[:, None, :]
    
                outputs_class, outputs_coord = self.head(features, x_boxes, t, None)  # 经过RCNNhead 预测类别和bbox
                output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
    
                if self.deep_supervision:
                    output['aux_outputs'] = [{'pred_logits': a, 'pred_boxes': b}
                                             for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
    
                loss_dict = self.criterion(output, targets)  # 计算loss
                weight_dict = self.criterion.weight_dict
                for k in loss_dict.keys():
                    if k in weight_dict:
                        loss_dict[k] *= weight_dict[k]
                return loss_dict
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    可以看到, 里面还有两个重点的self.prepare_targets(训练过程中的加噪)和self.ddim_sample(推理过程中的采样)

        def prepare_targets(self, targets):
            new_targets = []
            diffused_boxes = []
            noises = []
            ts = []
            for targets_per_image in targets:
                target = {}
                h, w = targets_per_image.image_size
                image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device)
                gt_classes = targets_per_image.gt_classes
                gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy
                gt_boxes = box_xyxy_to_cxcywh(gt_boxes)  # 以上预处理真值框
                d_boxes, d_noise, d_t = self.prepare_diffusion_concat(gt_boxes)  # 核心部分 计算加噪后的框
                diffused_boxes.append(d_boxes)
                noises.append(d_noise)
                ts.append(d_t)
                target["labels"] = gt_classes.to(self.device)
                target["boxes"] = gt_boxes.to(self.device)
                target["boxes_xyxy"] = targets_per_image.gt_boxes.tensor.to(self.device)
                target["image_size_xyxy"] = image_size_xyxy.to(self.device)
                image_size_xyxy_tgt = image_size_xyxy.unsqueeze(0).repeat(len(gt_boxes), 1)
                target["image_size_xyxy_tgt"] = image_size_xyxy_tgt.to(self.device)
                target["area"] = targets_per_image.gt_boxes.area().to(self.device)
                new_targets.append(target)  # target为蕴含大小、类别等信息的真值
            # 返回真值、加噪后的框、噪声和步长
            return new_targets, torch.stack(diffused_boxes), torch.stack(noises), torch.stack(ts)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26

    其中的加噪过程在self.prepare_diffusion_concat(gt_boxes), 我们可以看到:

    def prepare_diffusion_concat(self, gt_boxes):
            """
            :param gt_boxes: (cx, cy, w, h), normalized
            :param num_proposals:
            """
            t = torch.randint(0, self.num_timesteps, (1,), device=self.device).long()  # 确定随机步长
            noise = torch.randn(self.num_proposals, 4, device=self.device)  # 产生标准正态分布
    
            num_gt = gt_boxes.shape[0]  # gt框数目
            if not num_gt:  # generate fake gt boxes if empty gt boxes
                gt_boxes = torch.as_tensor([[0.5, 0.5, 1., 1.]], dtype=torch.float, device=self.device)
                num_gt = 1
    
            if num_gt < self.num_proposals:  # 如果gt框比预设的固定数目小 则随机再填充一些框
                box_placeholder = torch.randn(self.num_proposals - num_gt, 4,
                                              device=self.device) / 6. + 0.5  # 3sigma = 1/2 --> sigma: 1/6
                box_placeholder[:, 2:] = torch.clip(box_placeholder[:, 2:], min=1e-4)
                x_start = torch.cat((gt_boxes, box_placeholder), dim=0)
            elif num_gt > self.num_proposals:  # 如果比预设数目多 就随机抹掉一些GT框
                select_mask = [True] * self.num_proposals + [False] * (num_gt - self.num_proposals)
                random.shuffle(select_mask)
                x_start = gt_boxes[select_mask]
            else:
                x_start = gt_boxes
    
            x_start = (x_start * 2. - 1.) * self.scale
    
            # noise sample
            x = self.q_sample(x_start=x_start, t=t, noise=noise)  # 前向加噪过程
    
            x = torch.clamp(x, min=-1 * self.scale, max=self.scale)  # 限制范围
            x = ((x / self.scale) + 1) / 2.
    
            diff_boxes = box_cxcywh_to_xyxy(x)
    
            return diff_boxes, noise, t
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    最后再来看看推理阶段的self.ddim_sample函数:

    @torch.no_grad()
        def ddim_sample(self, batched_inputs, backbone_feats, images_whwh, images, clip_denoised=True, do_postprocess=True):
            batch = images_whwh.shape[0]
            shape = (batch, self.num_proposals, 4)
            total_timesteps, sampling_timesteps, eta, objective = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
    
            # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
            times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  
            times = list(reversed(times.int().tolist()))  # 时间为倒序 从T到0
            time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    
            img = torch.randn(shape, device=self.device)  # 产生标准高斯分布bboxs
    
            ensemble_score, ensemble_label, ensemble_coord = [], [], []
            x_start = None
            for time, time_next in time_pairs:  # 相邻时间两步计算
                time_cond = torch.full((batch,), time, device=self.device, dtype=torch.long)
                self_cond = x_start if self.self_condition else None  
    
                # 预测的噪声、x_0和类别与坐标
                preds, outputs_class, outputs_coord = self.model_predictions(backbone_feats, images_whwh, img, time_cond,
                                                                             self_cond, clip_x_start=clip_denoised)
                pred_noise, x_start = preds.pred_noise, preds.pred_x_start  # 获取预测的噪声\epsilon_t 与 预测的初始状态x_0
    
                if self.box_renewal:  # filter  Box reneral机制 将置信度低的边界框用随机框替换
                    score_per_image, box_per_image = outputs_class[-1][0], outputs_coord[-1][0]
                    threshold = 0.5
                    score_per_image = torch.sigmoid(score_per_image)
                    value, _ = torch.max(score_per_image, -1, keepdim=False)
                    keep_idx = value > threshold
                    num_remain = torch.sum(keep_idx)
    
                    pred_noise = pred_noise[:, keep_idx, :]
                    x_start = x_start[:, keep_idx, :]
                    img = img[:, keep_idx, :]
                if time_next < 0:
                    img = x_start
                    continue
                
                # 获取\alpha_i的连乘值
                alpha = self.alphas_cumprod[time]
                alpha_next = self.alphas_cumprod[time_next]
    
                sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
                c = (1 - alpha_next - sigma ** 2).sqrt()
    
                noise = torch.randn_like(img)  # 标准高斯分布中采样
    
                # 公式: x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} * x_0 + 
                # \sqrt(1 - \bar{\alpha}_{t-1}} - sigma^2} * \epsilon_t + 
                # \sigma * \epsilon
                img = x_start * alpha_next.sqrt() + \
                      c * pred_noise + \
                      sigma * noise  # 通过预测的噪声 计算恢复结果
    
                if self.box_renewal:  # filter
                    # replenish with randn boxes
                    img = torch.cat((img, torch.randn(1, self.num_proposals - num_remain, 4, device=img.device)), dim=1)
                if self.use_ensemble and self.sampling_timesteps > 1:
                    box_pred_per_image, scores_per_image, labels_per_image = self.inference(outputs_class[-1],
                                                                                            outputs_coord[-1],
                                                                                            images.image_sizes)
                    ensemble_score.append(scores_per_image)
                    ensemble_label.append(labels_per_image)
                    ensemble_coord.append(box_pred_per_image)
    
            if self.use_ensemble and self.sampling_timesteps > 1:
                box_pred_per_image = torch.cat(ensemble_coord, dim=0)
                scores_per_image = torch.cat(ensemble_score, dim=0)
                labels_per_image = torch.cat(ensemble_label, dim=0)
                if self.use_nms:
                    keep = batched_nms(box_pred_per_image, scores_per_image, labels_per_image, 0.5)
                    box_pred_per_image = box_pred_per_image[keep]
                    scores_per_image = scores_per_image[keep]
                    labels_per_image = labels_per_image[keep]
    
                result = Instances(images.image_sizes[0])
                result.pred_boxes = Boxes(box_pred_per_image)
                result.scores = scores_per_image
                result.pred_classes = labels_per_image
                results = [result]
            else:
                output = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
                box_cls = output["pred_logits"]
                box_pred = output["pred_boxes"]
                results = self.inference(box_cls, box_pred, images.image_sizes)
            if do_postprocess:  # 后处理
                processed_results = []
                for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
                    height = input_per_image.get("height", image_size[0])
                    width = input_per_image.get("width", image_size[1])
                    r = detector_postprocess(results_per_image, height, width)
                    processed_results.append({"instances": r})
                return processed_results
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90
    • 91
    • 92
    • 93
    • 94
  • 相关阅读:
    redis的数据类型
    linux命令汇总
    书籍《软技能》读后感
    ppt编辑技巧+提升效率的快捷键(2013以上版本)
    使用TortoiseGit导出两次提交时间之间的差异文件
    输出分离与输出抽象
    Django 模型层的操作(Django-05 )
    wireshark抓包本地IDEA xml格式报文教程以及postman调用接口
    群面的技巧
    【MATLAB高级编程】入门篇 | 向量化编程
  • 原文地址:https://blog.csdn.net/wjpwjpwjp0831/article/details/127973262