[隐藏左侧目录栏][显示左侧目录栏]
二分类交叉熵损失求导
本文的推导过程省略了偏置项 b;
1、前向传播过程中符号的定义
模型假设为非常简单的:一个全连接层后面接sigmoid做二分类。那么其前向传播的公式为:
\begin{equation}z^{(i)}=\mathbf{W}\vec{x^{(i)}}\end{equation}
\begin{equation}\hat{y}^{(i)}=\sigma(z^{(i)})\end{equation}
在上述公式中:
-
右上角的角标 \cdot^{(i)} 表示第 i 条样本;
-
\vec{x^{(i)}} 表示第 i 条样本的输入特征,是一个向量;假设输入样本特征的维度为 d_{input},则 \vec{x^{(i)}} = [x^{(i)}_1, x^{(i)}_2, ..., x^{(i)}_{d_{input}}];
-
\mathbf{W} 是权重矩阵;
-
\sigma(\cdot) 表示sigmoid函数;
-
z^{(i)} 表示全连接层的输出,同时也是sigmoid层的输入,由于这里是二分类,所以其是一个标量;
-
\hat{y}^{(i)} 是sigmoid层的输出,也是一个标量,表示的是模型对第 i 条样本预测为正例的概率,使用该值和 y^{(i)} 比较计算损失;
2、损失函数中的符号定义
损失函数的公式,如下所示:
\begin{equation}L^{(i)}=-\Big[ y^{(i)} \log \hat{y}^{(i)} + (1-y^{(i)})\log (1-\hat{y}^{(i)}) \Big]\end{equation}
\begin{equation}L=\frac{1}{N} \sum_{i=1}^N L^{(i)}\end{equation}
上述公式中的公式(4)只是对每条样本的损失 L^{(i)} 求均值,N 表示样本的总数,比较简单。后面的分析和求导将只对公式(3)进行。关于公式(3)是如何得到的,详见分类任务损失函数的原理;
至此,所有的符号说明完成。
3、二分类交叉熵损失求导
3.1 说明 \frac{\partial L^{(i)}}{\partial w} 与 \frac{\partial L^{(i)}}{\partial z^{(i)}} 的区别
根据链式求导法则,\frac{\partial L^{(i)}}{\partial w} 可以分为三部分,如下式所示:
\begin{equation}
\frac{\partial L^{(i)}}{\partial w} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \frac{\partial z^{(i)}}{\partial w}
\end{equation}
\frac{\partial L^{(i)}}{\partial z^{(i)}} 可以分为两部分,如下式所示:
\begin{equation}
\frac{\partial L^{(i)}}{\partial z^{(i)}} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}}
\end{equation}
它们之间相差了一个 \frac{\partial z^{(i)}}{\partial w},由前向传播的公式可知,这个导数取决于sigmoid层之前的网络结构是什么样的。在本文的第一部分假设是一个全连接层,所以其导数为 \frac{\partial z^{(i)}}{\partial w}=\vec{x^{(i)}};但其也可以是CNN层、RNN层、Transformer层等等,所以对于不同的网络 \frac{\partial z^{(i)}}{\partial w} 有着不同的求解方式,其并不属于二分类交叉熵损失部分的求导。所以本文后面在求导时求解的是 \frac{\partial L^{(i)}}{\partial z^{(i)}}。
3.2 求解 \frac{\partial L^{(i)}}{\partial z^{(i)}}
先使用链式求导法则求解 \frac{\partial L^{(i)}}{\partial z^{(i)}},求解过程如下所示:
\begin{equation}\begin{split}
\frac{\partial L^{(i)}}{\partial z^{(i)}} &= \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \\
&= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \sigma^{\prime}(z^{(i)}) \\
&= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \sigma(z^{(i)})(1-\sigma(z^{(i)})) \\
&= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \hat{y}^{(i)}(1-\hat{y}^{(i)}) \\
&= -[y^{(i)}(1-\hat{y}^{(i)}) - \hat{y}^{(i)}(1-y^{(i)})]\\
&= - (y^{(i)} - y^{(i)} \hat{y}^{(i)} - \hat{y}^{(i)} + y^{(i)} \hat{y}^{(i)}) \\
&= \hat{y}^{(i)} - y^{(i)}
\end{split}\end{equation}
说明:上述推导中使用到了sigmoid的求导公式,若 f(z)=\frac{1}{1+e^{-z}},则有:f^{\prime}(z) = f(z)(1 - f(z))
至此求解出二分类交叉熵损失的导数为:
\begin{equation}\frac{\partial L^{(i)}}{\partial z^{(i)}}=\hat{y}^{(i)} - y^{(i)}\end{equation}
4、总结
本文主要是对二分类交叉熵损失进行求导;
Reference
太早了,不记得了...