• 强化学习基础-标量对矩阵的求导术


    以下来自于知乎文章《机器学习中的数学理论1:三步搞定矩阵求导》

    在机器学习,控制论中总会遇到这样或那样需要借助矩阵或者向量求导才能解决的问题(例:Gradient Descent)。这类问题对于在机器学习中分析,推导,应用其原理性理论有很重要的作用。

    x : x: x:标量; x : \mathbf x: x:向量; X X X:矩阵

    1.预备背景

    1.1 常用矩阵微分运算法则

    • d ( X ± Y ) = d ( X ) ± d ( Y ) d(X\pm Y)=d(X)\pm d(Y) d(X±Y)=d(X)±d(Y)

    • d ( X Y ) = X d Y + ( d X ) Y d(XY)=XdY+(dX)Y d(XY)=XdY+(dX)Y

    • d ( X T ) = ( d X ) T d(X^T)=(dX)^T d(XT)=(dX)T

    • d t r ( X ) = t r d ( X ) dtr(X)=trd(X) dtr(X)=trd(X)

    • d ( X ⊙ Y ) = d X ⊙ Y + X ⊙ d Y d(X\odot Y)=dX\odot Y+X \odot dY d(XY)=dXY+XdY

    • d X − 1 = − X − 1 d X X − 1 dX^{-1}=-X^{-1}dXX^{-1} dX1=X1dXX1

    • d ∣ X ∣ = ∣ X ∣ t r ( X − 1 d X ) d|X|=|X|tr(X^{-1}dX) dX=Xtr(X1dX)

    • d σ ( X ) = σ ‘ ( X ) ⊙ d X d\sigma(X)=\sigma^{`}(X)\odot dX dσ(X)=σ(X)dX

    1.2 常用矩阵迹运算法则

    • a = t r ( a ) a=tr(a) a=tr(a),当a是标量
    • t r ( A T ) = t r ( A ) tr(A^T)=tr(A) tr(AT)=tr(A)
    • t r ( A ± B ) = t r ( A ) ± t r ( B ) tr(A\pm B)=tr(A)\pm tr(B) tr(A±B)=tr(A)±tr(B)
    • t r ( A B ) = t r ( B A ) = ∑ i , j A i j B i j tr(AB)=tr(BA)=\sum_{i,j}A_{ij}B_{ij} tr(AB)=tr(BA)=i,jAijBij,当 A A A B T B^T BT尺寸相同时
    • t r ( A T ( B ⊙ C ) ) = t r ( ( A ⊙ B ) T C ) = ∑ i , j A i j B i j C i j tr(A^T(B\odot C))=tr((A\odot B)^TC)=\sum_{i,j}A_{ij}B_{ij}C_{ij} tr(AT(BC))=tr((AB)TC)=i,jAijBijCij

    1.3 常用矩阵直积运算法则

    • A ⨂ B ≠ B ⨂ A A\bigotimes B \neq B\bigotimes A AB=BA

    • ( A 1 + A 2 ) ⨂ B = A 1 ⨂ B + A 2 ⨂ B (A_1+A_2)\bigotimes B=A_1\bigotimes B+A_2\bigotimes B (A1+A2)B=A1B+A2B

    • ( A ⨂ B ) ⨂ C = A ⨂ ( B ⨂ C ) (A\bigotimes B)\bigotimes C=A\bigotimes (B\bigotimes C) (AB)C=A(BC)

    • A 1 , A 2 A_1,A_2 A1,A2可以做乘法运算, B 1 , B 2 B_1,B_2 B1,B2可以做乘法运算:
      ( A 1 ⨂ A 2 ) ( B 1 ⨂ B 2 ) = ( A 1 A 2 ) ⨂ ( B 1 B 2 ) (A_1\bigotimes A_2)(B_1\bigotimes B_2)=(A_1A_2)\bigotimes(B_1B_2) (A1A2)(B1B2)=(A1A2)(B1B2)

    • A , B A,B A,B可以求逆:
      ( A ⨂ B ) − 1 = A − 1 ⨂ B − 1 (A \bigotimes B)^{-1}=A^{-1}\bigotimes B^{-1} (AB)1=A1B1
      若不能求逆运算则:
      ( A ⨂ B ) + = A + ⨂ B + (A\bigotimes B)^{+}=A^{+}\bigotimes B^{+} (AB)+=A+B+

    • ( A ⨂ B ) H = A H ⨂ B H (A\bigotimes B)^H=A^H\bigotimes B^H (AB)H=AHBH

    • d e t ( A ⨂ B ) = ( d e t A ) n ( d e t B ) m ( A ∈ C m × m , B ∈ C n × n ) det(A\bigotimes B)=(detA)^n(detB)^m(A\in C^{m\times m},B\in C^{n\times n}) det(AB)=(detA)n(detB)m(ACm×m,BCn×n)

    • t r ( A ⨂ B ) = ( t r A ) ⨂ ( t r B ) tr(A\bigotimes B)=(trA)\bigotimes (trB) tr(AB)=(trA)(trB)

    • r a n k ( A ⨂ B ) = r a n k A ⨂ r a n k B rank(A\bigotimes B)=rankA\bigotimes rankB rank(AB)=rankArankB

    • e I ⨂ A = I ⨂ e A , e A ⨂ I = A ⨂ I e^{I \bigotimes A} = I\bigotimes e^A,e^{A\bigotimes I} = A\bigotimes I eIA=IeA,eAI=AI

    • e ( A ⨂ I n + I m ⨂ B ) = e A ⨂ e B e^{(A\bigotimes I_n+I_m \bigotimes B)}=e^A\bigotimes e^B e(AIn+ImB)=eAeB

    2.标量对矩阵的求导术

    2.1 算法流程

    I n p u t : X , f \mathbf {Input}:X,f Input:X,f

    O u t p u t : ∂ f ∂ X \mathbf{Output}:\frac{\partial f}{\partial X} Output:Xf

    A l g o r i t h m \mathbf{Algorithm} Algorithm:

    1. 根据 f f f寻找 d f df df.
    2. d f df df左右两边套 t r tr tr: t r ( d f ) = d f tr(df)=df tr(df)=df
    3. 根据 d f = t r ( ∂ f T ∂ X d X ) df=tr(\frac{\partial f^T}{\partial X}dX) df=tr(XfTdX)凑出 ∂ f ∂ X \frac{\partial f}{\partial X} Xf

    2.2 习题

    在这里插入图片描述

    2.解: 首先对 f f f左右两边求微分,令 u = X b u=Xb u=Xb:
    1. d f = a T d ( exp ⁡ ( u ) ) = a T exp ⁡ ( u ) d u = a T exp ⁡ ( X b ) ⊙ ( d X b ) 1.df=a^Td(\exp(u))=a^T\exp(u)du=a^T\exp(Xb)\odot(dXb)\\ 1.df=aTd(exp(u))=aTexp(u)du=aTexp(Xb)(dXb)

    2. d f = t r ( d f ) = t r ( a T ( exp ⁡ ( X b ) ⊙ ( d X b ) ) ) = t r ( ( a ⊙ exp ⁡ ( X b ) ) T d X b ) = t r ( b ( a ⊙ exp ⁡ ( X b ) ) T d X ) = t r ( ( a ⊙ exp ⁡ ( X b ) b T ) T d X ) 2.df=tr(df)=tr(a^T(\exp(Xb)\odot(dXb)))\\ =tr((a\odot\exp(Xb))^T dXb)\\ =tr(b(a\odot\exp(Xb))^T dX)\\ =tr((a\odot\exp(Xb)b^T)^T dX)\\ 2.df=tr(df)=tr(aT(exp(Xb)(dXb)))=tr((aexp(Xb))TdXb)=tr(b(aexp(Xb))TdX)=tr((aexp(Xb)bT)TdX)

    3. 由 d f = t r ( ∂ f T ∂ X d X )   ∂ f ∂ X = a ⊙ exp ⁡ ( X b ) b T 3.由df=tr(\frac{\partial f^T}{\partial X}dX)\\\ \frac{\partial f}{\partial X}=a\odot\exp(Xb)b^T 3.df=tr(XfTdX) Xf=aexp(Xb)bT

    在这里插入图片描述

    1. 解:对上述 l l l可知: l = ( X w − y ) T ( X w − y ) l=(Xw-y)^T(Xw-y) l=(Xwy)T(Xwy):
      1. d l = ( X d w ) T ( X w − y ) + ( X w − y ) T ( X d w ) 2. d l = t r ( d l ) = t r ( 2 ( X w − y ) T X d w ) 3. 由 d l = t r ( ∂ f T ∂ w d w ) → ∂ f ∂ w = 2 X T ( X w − y ) 1.dl=(Xdw)^T(Xw-y)+(Xw-y)^T(Xdw)\\ 2.dl = tr(dl)=tr(2(Xw-y)^TXdw)\\ 3.由dl=tr(\frac{\partial f^T}{\partial w}dw)\rightarrow \frac{\partial f}{\partial w}=2X^T(Xw-y) 1.dl=(Xdw)T(Xwy)+(Xwy)T(Xdw)2.dl=tr(dl)=tr(2(Xwy)TXdw)3.dl=tr(wfTdw)wf=2XT(Xwy)
      在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    在这里插入图片描述

    2.3 标量对矩阵求导的链式法则

    2.3.1 向量对向量求导链式法则

    假设向量(列向量)之间存在依赖关系,比如: x → y → z \mathbf x\rightarrow \mathbf y \rightarrow \mathbf z xyz,则有:
    ∂ z ∂ x = ∂ z ∂ y ∂ y ∂ x \frac{\partial \mathbf z}{\partial \mathbf x}=\frac{\partial \mathbf z}{\partial \mathbf y}\frac{\partial \mathbf y}{\partial \mathbf x} xz=yzxy

    2.3.2 标量对多个向量的链式求导法则

    假设向量(列向量)之间存在依赖关系,比如: x → y → z \mathbf x\rightarrow \mathbf y \rightarrow \mathbf z xyz,要求导的是标量 z z z。那么就有: ∂ z ∂ y : n × 1 , ∂ z ∂ x : m × 1 , ∂ y ∂ x : n × m \frac{\partial z}{\partial \mathbf y}:n\times 1,\frac{\partial z}{\partial \mathbf x}:m\times 1,\frac{\partial \mathbf y}{\partial \mathbf x}:n\times m yz:n×1,xz:m×1,xy:n×m,则: ∂ z ∂ x = ( ∂ y ∂ x ) T ∂ z ∂ y \frac{\partial \mathbf z}{\partial \mathbf x}=(\frac{\partial \mathbf y}{\partial \mathbf x})^T\frac{\partial \mathbf z}{\partial \mathbf y} xz=(xy)Tyz。当形式更为复杂有:
    y 1 → y 1 → . . . y n → z \mathbf y_1 \rightarrow \mathbf y_1\rightarrow ...\mathbf y_n\rightarrow z y1y1...ynz
    那链式法则为:
    ∂ z ∂ y 1 = ( ∂ y n ∂ y n − 1 ∂ y n − 1 ∂ y n − 2 . . . ∂ y 2 ∂ y 1 ) T ∂ z ∂ y n \frac{\partial z}{\partial \mathbf y_1}=(\frac{\partial \mathbf y_{n}}{\partial \mathbf y_{n-1}}\frac{\partial \mathbf y_{n-1}}{\partial \mathbf y_{n-2}}...\frac{\partial \mathbf y_2}{\partial \mathbf y_1})^T\frac{\partial z}{\partial \mathbf y_n} y1z=(yn1ynyn2yn1...y1y2)Tynz
    在这里插入图片描述

    2.3.3 标量对多个矩阵的链式求导法则

    在这里插入图片描述

    在这里插入图片描述

  • 相关阅读:
    java的网络编程
    使用kubasz快速搭建Kubernetes集群
    Linux生成动态库
    基于surging网络组件多协议适配的平台化发展
    如何在代码层面提高CPU分支预测效率
    关于企业微信中开发第三方应用遇到的退出问题
    浏览器本地存储webStroage
    SLM2110 600V 2A 逆变电源专用芯片替代IR2110S 移动储能解决方案
    hive-行转列
    vue项目+xlsx+xlsx-style 实现table导出为excel的功能——技能提升
  • 原文地址:https://blog.csdn.net/shengzimao/article/details/125527405