[隐藏左侧目录栏][显示左侧目录栏]

均方差损失与交叉熵损失的异同#

关于交叉熵损失的求导在之前的文章 二分类交叉熵损失求导 中已经求解过了,本文会先对均方差损失求导,然后利用导数对比均方差损失和交叉熵损失。

1、均方差损失求导#

这里假设使用均方差损失做二分类任务;

1.1 前向传播过程中符号的定义#

模型假设为非常简单的:一个全连接层后面接sigmoid做二分类。那么其前向传播的公式为:

\begin{equation}z^{(i)}=\mathbf{W}\vec{x^{(i)}}\end{equation}
\begin{equation}\hat{y}^{(i)}=\sigma(z^{(i)})\end{equation}

在上述公式中:

  • 右上角的角标 \cdot^{(i)} 表示第 i 条样本;

  • \vec{x^{(i)}} 表示第 i 条样本的输入特征,是一个向量;假设输入样本特征的维度为 d_{input},则 \vec{x^{(i)}} = [x^{(i)}_1, x^{(i)}_2, ..., x^{(i)}_{d_{input}}]

  • \mathbf{W} 是权重矩阵;

  • \sigma(\cdot) 表示sigmoid函数;

  • z^{(i)} 表示全连接层的输出,同时也是sigmoid层的输入,其是一个标量;

  • \hat{y}^{(i)} 是sigmoid层的输出,也是一个标量,表示的是模型对第 i 条样本预测为正例的概率,使用该值和 y^{(i)} 比较计算损失;

1.2 损失函数中的符号定义#

损失函数的公式,如下所示:

\begin{equation}L^{(i)}=\frac{1}{2}(y^{(i)} - \hat{y}^{(i)})^2\end{equation}
\begin{equation}L=\frac{1}{N} \sum_{i=1}^N L^{(i)}\end{equation}

上述公式中的公式(4)只是对每条样本的损失 L^{(i)} 求均值,N 表示样本的总数,比较简单。后面的分析和求导将只对公式(3)进行。

1.3 求导#

1.3.1 说明 \frac{\partial L^{(i)}}{\partial w}\frac{\partial L^{(i)}}{\partial z^{(i)}} 的区别#

根据链式求导法则,\frac{\partial L^{(i)}}{\partial w} 可以分为三部分,如下式所示:

\begin{equation} \frac{\partial L^{(i)}}{\partial w} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \frac{\partial z^{(i)}}{\partial w} \end{equation}

\frac{\partial L^{(i)}}{\partial z^{(i)}} 可以分为两部分,如下式所示:

\begin{equation} \frac{\partial L^{(i)}}{\partial z^{(i)}} = \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \end{equation}

它们之间相差了一个 \frac{\partial z^{(i)}}{\partial w},由前向传播的公式可知,这个导数取决于sigmoid层之前的网络结构是什么样的。在本文的 1.1 部分假设是一个全连接层,所以其导数为 \frac{\partial z^{(i)}}{\partial w}=\vec{x^{(i)}};但其也可以是CNN层、RNN层、Transformer层等等,所以对于不同的网络 \frac{\partial z^{(i)}}{\partial w} 有着不同的求解方式,其并不属于二分类交叉熵损失部分的求导。所以本文后面在求导时求解的是 \frac{\partial L^{(i)}}{\partial z^{(i)}}

1.3.2 求 \frac{\partial L^{(i)}}{\partial z^{(i)}}#

求导过程非常简单,如下所示:

\begin{equation}\begin{split} \frac{\partial L^{(i)}}{\partial z^{(i)}} &= \frac{\partial L^{(i)}}{\partial \hat{y}^{(i)}} \frac{\partial \hat{y}^{(i)}}{\partial z^{(i)}} \\ &= (\hat{y}^{(i)} - y^{(i)}) \sigma^{\prime}(z^{(i)}) \end{split}\end{equation}

其中 \sigma^{\prime}( \cdot ) 表示sigmoid函数的导数;

2、均方差损失与交叉熵损失的异同#

在文章 二分类交叉熵损失求导 中已经求解出了二分类交叉熵损失的导数,下面将均方差损失的导数和交叉熵损失的导数都放到这里进行对比:

\begin{equation}\begin{split} &\frac{\partial L^{(i)}}{\partial z^{(i)}}=\hat{y}^{(i)} - y^{(i)} \qquad // \text{二分类交叉熵损失的梯度} \\ &\frac{\partial L_i}{\partial z^{(i)}}=(\hat{y}^{(i)} - y^{(i)}) \sigma^{\prime}(z^{(i)}) \qquad// \text{二分类均方差损失的梯度} \end{split}\end{equation}

先分析Sigmoid函数的性质。下图为Sigmoid函数的图像,可以看出当 z_i 比较小或者比较大时, \sigma^{\prime}(z_i) 的值(即下图曲线的斜率)都趋于0;也就是说当 z_i 比较小或者比较大时,其梯度也是趋于0的。

在做梯度下降时,希望当离目标较远时,每次更新的步长要大一些,可以快速收敛;当离目标较近时,每次更新的步长要小一些,可以防止在目标值附近震荡;

对比交叉熵损失和均方差损失的梯度,可以看出它们两个只差一个 \sigma^{\prime}(z^{(i)}) ,所以有如下结论:

  • 使用交叉熵损失做二分类任务时,其梯度为期望输出与实际输出的差值,当距离目标点越远时,该差值越大,梯度越大;当距离目标点越近时,该差值越小,梯度越小;这符合之前所希望的;

  • 使用均方差损失做而分类任务时,由于其梯度中包含了 \sigma^{\prime}(z^{(i)}) 这一项,当距离目标点比较远时,梯度较小(趋于0);当距离目标点较近时,梯度也较小(趋于0),不利于优化;

3、多分类任务时的异同#

上面两部分的分析是使用交叉熵损失和均方差损失做二分类任务时的区别,在多分类任务上分析过程是类似的,步骤为:

  • 求解交叉熵损失在多分类时的梯度;
  • 求解均方差损失在多分类时的梯度;
  • 对比分析两个梯度在做梯度下降时的效果;

多分类交叉熵的梯度前面已经求解过了,详情见: 多分类交叉熵损失求导

均方差损失在做多分类任务时,梯度的求解非常复杂,奈何数学不好,先空着,以后再看;

4、总结#

本文主要是对比了交叉熵损失和均方差损失在做 分类任务 时的异同。在二分类任务中,均方差损失的梯度中由于存在 \sigma^{\prime}(z^{(i)}) 这一项,导致其在距离目标点非常远和非常近时的梯度都较小,使得优化效果较差。另外多分类任务的均方差损失的梯度由于分析比较复杂,没有做详细分析。