[隐藏左侧目录栏][显示左侧目录栏]

多分类交叉熵损失求导#

1.由于不想每个向量都打上右上角的转置符号,所以本文的向量都是行向量,其与列向量没有本质区别;

2.本文的推导过程省略了偏置项 b

多分类也就是softmax的损失的求导主要问题在于符号的定义上,符号表示说清楚了,求导是比较容易的。

1、前向传播过程中的符号定义#

模型假设为非常简单的:一个全连接层后接softmax做分类。那么其前向传播的公式为:

z(i)=Wx(i)ˆy(i)=softmax(z(i))

在上述公式中:

  • 右上角的角标 (i) 表示第 i 条样本;

  • x(i) 表示第 i 条样本的输入特征,是一个向量;假设输入样本特征的维度为 dinput,则 x(i)={x(i)1,x(i)2,...,x(i)dinput}

  • W 是权重矩阵,假设其维度为 (dinput,d);那么由于 x(i) 维度为 (1,dinput)W 的维度为 (dinput,d),可以得出 z(i) 的维度为 (1,d)。(由于本文中都是行向量,所以上面公式(1)中写作 x(i)W 更合适)

  • z(i) 表示全连接层的输出,同时也是softmax层的输入,按照上面的假设,其维度为 (1,d),即 z(i)={z(i)1,z(i)2,...,z(i)d}

  • ˆy(i) 表示softmax层的输出,在一般的分类模型中也是最后一层的输出,使用该值和 y(i) 比较计算损失;softmax运算并不会改变向量的维度,所以 ˆy(i) 的维度也是 d(从这可以看出这个维度 d 就是该多分类任务的类别个数),有 ˆy(i)={ˆy(i)1,ˆy(i)2,...,ˆy(i)d}

  • 右下角的角标表示第几个维度的元素,所以仅向量/矩阵中的某个元素才有右下角的角标,向量/矩阵是肯定不会有右下角的角标的;

2、损失函数中的符号定义#

损失函数的公式,如下所示:

L(i)=dk=1y(i)klnˆy(i)k
L=1NNi=1L(i)

上述公式中的公式(3)没有任何问题,只是对每条样本的损失求均值,N 表示样本的总数,后面的分析和求导将只对公式(2)进行。

这里的公式(2)和常见的多分类任务的损失函数形式上不太一致,常见的损失函数一般都是如下公式(4)的形式,下面专门说明一下这两种形式的损失函数:

L(i)=lnˆy(i)

容易知道对于多分类问题,其每条样本的标签对应的向量中,所有的元素里面有且仅有一个元素为1,其他元素都是0,即向量 y(i)={y(i)1,y(i)2,...,y(i)d} 中仅有一个元素为1,其他元素都为0。为了不失一般性,假设第 j 个元素为1,则有:

y(i)={y(i)1,y(i)2,...,y(i)j1,y(i)j,y(i)j+1,...,y(i)d}={0,0,...,0,1,0,...,0}

ˆy(i)y(i) 的各个元素代入到公式(2)中:

L(i)=dk=1y(i)klnˆy(i)k=[y(i)1lnˆy(i)1+y(i)2lnˆy(i)2+...+y(i)jlnˆy(i)j+...+y(i)dlnˆy(i)d]=y(i)jlnˆy(i)j=lnˆy(i)j

上述公式(6)中第2行推导出第3行的原因是只有 y(i)j 为1,其他的像 y(i)1y(i)2y(i)d 等都为0;公式(6)最后的结果就是公式(4)的形式,只不过它把右下角的下标省略了。需要注意的是:公式(4)中的 ˆy(i) 是一个标量,它是公式(1)中的向量 ˆy(i) 中的一个元素,这是一个容易混淆的地方。

至此,所有的符号说明完成,后面在求导时使用的损失函数是公式(2)的形式。

3、多分类交叉熵损失求导#

3.1 说明 L(i)wL(i)z(i) 的区别#

根据链式求导法则,L(i)w 可以分为三部分,如下式所示:

L(i)w=L(i)ˆy(i)ˆy(i)z(i)z(i)w

L(i)z(i) 可以分为两部分,如下式所示:

L(i)z(i)=L(i)ˆy(i)ˆy(i)z(i)

它们之间相差了一个 z(i)w,由前向传播的公式可知,这个导数取决于softmax层之前的网络结构是什么样的。在本文的第一部分假设是一个全连接层,所以其导数为 z(i)w=x(i);但其也可以是CNN层、RNN层、Transformer层等等,所以对于不同的网络 z(i)w 有着不同的求解方式,其并不属于多分类交叉熵损失部分的求导。所以本文后面在求导时求解的是 L(i)z(i)

3.2 总体分析#

先总体分析一下:

  • L(i)z(i) 是标量对向量求导(标量对向量求导:标量分别对向量中的每个元素求导,最终结果是一个向量,维度与原向量相同),结果是一个向量,维度为 (1,d)

  • L(i)ˆy(i) 是标量对向量求导,结果是一个向量,维度为 (1,d)

  • ˆy(i)z(i) 是向量对向量求导,结果是Jacobi矩阵,维度为 (d,d)

  • L(i)ˆy(i)ˆy(i)z(i) 相乘,是维度为 (1,d) 的向量与维度为 (d,d) 的矩阵相乘,最终结果是一个维度为 (1,d) 的向量;

总体来看:在链式求导公式(8)中的各个维度是能够对应上的,没有问题。

3.3 求导#

先求 L(i)ˆy(i),求解结果如下:

L(i)ˆy(i)=[L(i)ˆy(i)1,L(i)ˆy(i)2,...,L(i)ˆy(i)j,...,L(i)ˆy(i)d]=[0,0,...,1ˆy(i)j,...,0]

再求 ˆy(i)z(i),这个在文章 Softmax函数求导 中已经求解过了,直接使用其结论:

ˆy(i)z(i)=[ˆy(i)1(ˆy(i)1)2ˆy(i)1ˆy(i)2ˆy(i)1ˆy(i)dˆy(i)2ˆy(i)1ˆy(i)2(ˆy(i)2)2ˆy(i)2ˆy(i)dˆy(i)dˆy(i)1ˆy(i)dˆy(i)2ˆy(i)d(ˆy(i)d)2]

然后将 L(i)ˆy(i)ˆy(i)z(i) 相乘就得到了最终结果;可以看到 L(i)ˆy(i) 中只有第 j 个元素非0,所以在 ˆy(i)z(i) 中只有第 j 行会被使用到,其他行都与0相乘消去了:

L(i)z(i)=L(i)ˆy(i)ˆy(i)z(i)=1ˆy(i)j[ˆy(i)jˆy(i)1,ˆy(i)jˆy(i)2,...,ˆy(i)j(ˆy(i)j)2,...,ˆy(i)jˆy(i)d]=[ˆy(i)1,ˆy(i)2,...,ˆy(i)j1,...,ˆy(i)d]=ˆy(i)y(i)

求解过程很复杂,求解的结果却很优雅。因为softmax函数和损失本身就是精心设计过的,才能在计算梯度时非常简单、高效。

4、总结#

本文主要是对softmax交叉熵损失做求导,用于进一步理解在多分类任务中反向传播的细节过程。

Reference#