[隐藏左侧目录栏][显示左侧目录栏]

二分类交叉熵损失求导#

本文的推导过程省略了偏置项 b

1、前向传播过程中符号的定义#

模型假设为非常简单的:一个全连接层后面接sigmoid做二分类。那么其前向传播的公式为:

\begin{equation}z^{(i)}=\mathbf{W}\vec{x^{(i)}}\end{equation}
\begin{equation}\hat{y}^{(i)}=\sigma(z^{(i)})\end{equation}

在上述公式中:

  • 右上角的角标 \cdot^{(i)} 表示第 i 条样本;

  • \vec{x^{(i)}} 表示第 i 条样本的输入特征,是一个向量;假设输入样本特征的维度为 d_{input},则 \vec{x^{(i)}} = [x^{(i)}_1, x^{(i)}_2, ..., x^{(i)}_{d_{input}}]

  • \mathbf{W} 是权重矩阵;

  • \sigma(\cdot) 表示sigmoid函数;

  • z^{(i)} 表示全连接层的输出,同时也是sigmoid层的输入,由于这里是二分类,所以其是一个标量;

  • \hat{y}^{(i)} 是sigmoid层的输出,也是一个标量,表示的是模型对第 i 条样本预测为正例的概率,使用该值和 y^{(i)} 比较计算损失;

2、损失函数中的符号定义#

损失函数的公式,如下所示:

\begin{equation}L^{(i)}=-\Big[ y^{(i)} \log \hat{y}^{(i)} + (1-y^{(i)})\log (1-\hat{y}^{(i)}) \Big]\end{equation}
\begin{equation}L=\frac{1}{N} \sum_{i=1}^N L^{(i)}\end{equation}

上述公式中的公式(4)只是对每条样本的损失 L^{(i)} 求均值,N 表示样本的总数,比较简单。后面的分析和求导将只对公式(3)进行。关于公式(3)是如何得到的,详见分类任务损失函数的原理

至此,所有的符号说明完成。

3、二分类交叉熵损失求导#

3.1 说明 \frac{\partial L^{(i)}}{\partial w}\frac{\partial L^{(i)}}{\partial z^{(i)}} 的区别#

根据链式求导法则,\frac{\partial L^{(i)}}{\partial w} 可以分为三部分,如下式所示:

\begin{equation} \frac{\partial L^{(i)}}{\partial w} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \frac{\partial z^{(i)}}{\partial w} \end{equation}

\frac{\partial L^{(i)}}{\partial z^{(i)}} 可以分为两部分,如下式所示:

\begin{equation} \frac{\partial L^{(i)}}{\partial z^{(i)}} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \end{equation}

它们之间相差了一个 \frac{\partial z^{(i)}}{\partial w},由前向传播的公式可知,这个导数取决于sigmoid层之前的网络结构是什么样的。在本文的第一部分假设是一个全连接层,所以其导数为 \frac{\partial z^{(i)}}{\partial w}=\vec{x^{(i)}};但其也可以是CNN层、RNN层、Transformer层等等,所以对于不同的网络 \frac{\partial z^{(i)}}{\partial w} 有着不同的求解方式,其并不属于二分类交叉熵损失部分的求导。所以本文后面在求导时求解的是 \frac{\partial L^{(i)}}{\partial z^{(i)}}

3.2 求解 \frac{\partial L^{(i)}}{\partial z^{(i)}}#

先使用链式求导法则求解 \frac{\partial L^{(i)}}{\partial z^{(i)}},求解过程如下所示:

\begin{equation}\begin{split} \frac{\partial L^{(i)}}{\partial z^{(i)}} &= \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \\ &= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \sigma^{\prime}(z^{(i)}) \\ &= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \sigma(z^{(i)})(1-\sigma(z^{(i)})) \\ &= -(\frac{y^{(i)}}{\hat{y}^{(i)}} - \frac{1-y^{(i)}}{1-\hat{y}^{(i)}}) \hat{y}^{(i)}(1-\hat{y}^{(i)}) \\ &= -[y^{(i)}(1-\hat{y}^{(i)}) - \hat{y}^{(i)}(1-y^{(i)})]\\ &= - (y^{(i)} - y^{(i)} \hat{y}^{(i)} - \hat{y}^{(i)} + y^{(i)} \hat{y}^{(i)}) \\ &= \hat{y}^{(i)} - y^{(i)} \end{split}\end{equation}

说明:上述推导中使用到了sigmoid的求导公式,若 f(z)=\frac{1}{1+e^{-z}},则有:f^{\prime}(z) = f(z)(1 - f(z))

至此求解出二分类交叉熵损失的导数为:

\begin{equation}\frac{\partial L^{(i)}}{\partial z^{(i)}}=\hat{y}^{(i)} - y^{(i)}\end{equation}

4、总结#

本文主要是对二分类交叉熵损失进行求导;

Reference#

太早了,不记得了...