使用极大似然的思想,首先引入log函数,保证函数单调性不变,那么根据log函数的单调性,想要P(y|x)越大,那么可以让-P(y|x)越小,其实就是说,让其概率值更大,反方向理解就是损失更小才能作为损失函数来用,那么交叉熵损失函数就是多个样本损失函数的和,N个样本的和就是:
L
=
−
∑
i
=
1
N
(
y
i
log
y
^
i
+
(
1
−
y
i
)
log
(
1
−
y
^
i
)
)
(4)
L = -\sum^N_{i=1}(y_{i}\log{\hat{y}_{i}} + (1-y_{i})\log{(1-\hat{y}}_{i}))\tag{4}
L=−i=1∑N(yilogy^i+(1−yi)log(1−y^i))(4)
再从交叉熵损失函数的图像来理解(单个样本损失函数)
横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大
预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大
关于分类问题的损失函数常用交叉熵损失函数,而非均方误差MSE
从两者表达式来看
便于理解,我们用上图做一个简单的推导
Z
(
x
)
=
w
∗
b
,
A
(
z
)
=
σ
(
z
)
=
1
1
+
e
−
z
(5)
Z(x) = w * b, A(z) = σ(z)= \frac{1}{1 + e ^ {-z}} \tag{5}
Z(x)=w∗b,A(z)=σ(z)=1+e−z1(5)
那么MSE损失表达式就是:(A为分类结果的概率值,y为真实分类值,即0或者1)
C
=
(
A
−
y
)
2
2
(6)
C = \frac{(A - y)^2}{2}\tag{6}
C=2(A−y)2(6)
使用梯度下降法的更新w和b时,对w和b进行求导
∂
C
∂
w
=
∂
C
∂
A
∂
A
∂
Z
∂
Z
∂
w
=
(
A
−
y
)
σ
′
(
Z
)
x
=
(
A
−
y
)
A
(
1
−
A
)
x
≈
A
σ
′
(
z
)
(7)
\frac{\partial C}{\partial w} = \frac{\partial C}{\partial A }\frac{\partial A}{\partial Z }\frac{\partial Z}{\partial w } = (A - y)σ'(Z)x\tag{7} = (A - y)A(1-A)x \approx Aσ'(z)
∂w∂C=∂A∂C∂Z∂A∂w∂Z=(A−y)σ′(Z)x=(A−y)A(1−A)x≈Aσ′(z)(7)
同理对b求导
∂
C
∂
b
=
∂
C
∂
A
∂
A
∂
Z
∂
Z
∂
b
=
(
A
−
y
)
σ
′
(
Z
)
=
(
A
−
y
)
A
(
1
−
A
)
≈
A
σ
′
(
z
)
(8)
\frac{\partial C}{\partial b} = \frac{\partial C}{\partial A }\frac{\partial A}{\partial Z }\frac{\partial Z}{\partial b } = (A - y)σ'(Z)\tag{8} = (A - y)A(1-A) \approx Aσ'(z)
∂b∂C=∂A∂C∂Z∂A∂b∂Z=(A−y)σ′(Z)=(A−y)A(1−A)≈Aσ′(z)(8)
更新后的w和b:
w
=
w
−
η
∂
C
∂
w
=
w
−
η
A
σ
′
(
z
)
(9)
w = w - \eta \frac{\partial C}{\partial w} = w - \eta A σ'(z)\tag{9}
w=w−η∂w∂C=w−ηAσ′(z)(9)
b
=
b
−
η
∂
C
∂
b
=
b
−
η
A
σ
′
(
z
)
(10)
b = b - \eta \frac{\partial C}{\partial b} = b - \eta A σ'(z)\tag{10}
b=b−η∂b∂C=b−ηAσ′(z)(10)
交叉熵损失函数同理推导,其中交叉熵误差表达公式为:(其实需要累加,此处方便理解就不累加了)
L
=
−
(
y
∗
l
n
(
a
)
+
(
1
−
y
)
∗
l
n
(
1
−
a
)
)
(11)
L = -(y * ln(a) + (1-y)*ln(1-a))\tag{11}
L=−(y∗ln(a)+(1−y)∗ln(1−a))(11)
推导过程如下:(推导过程可以参考上面mse损失推导过程,(5)依旧可用,求偏导的步骤可以参考(7))
∂
L
∂
w
=
(
−
y
a
+
1
−
y
1
−
a
)
x
σ
′
(
z
)
(12)
\frac{\partial L}{\partial w} = (- \frac{y}{a} + \frac{1-y}{1-a})xσ'(z)\tag{12}
∂w∂L=(−ay+1−a1−y)xσ′(z)(12)
注:σ’(z) = σ(z) * (1 - σ(z)) = a * (1 - a),推导过程如上图手写部分
∂
L
∂
w
=
(
a
y
−
y
+
a
−
a
y
)
x
=
(
a
−
y
)
x
(13)
\frac{\partial L}{\partial w} = (ay -y + a - ay)x = (a-y)x\tag{13}
∂w∂L=(ay−y+a−ay)x=(a−y)x(13)