多分类交叉熵损失求导#
1.由于不想每个向量都打上右上角的转置符号,所以本文的向量都是行向量,其与列向量没有本质区别;
2.本文的推导过程省略了偏置项 b;
多分类也就是softmax的损失的求导主要问题在于符号的定义上,符号表示说清楚了,求导是比较容易的。
1、前向传播过程中的符号定义#
模型假设为非常简单的:一个全连接层后接softmax做分类。那么其前向传播的公式为:
在上述公式中:
-
右上角的角标 ⋅(i) 表示第 i 条样本;
-
→x(i) 表示第 i 条样本的输入特征,是一个向量;假设输入样本特征的维度为 dinput,则 →x(i)={x(i)1,x(i)2,...,x(i)dinput};
-
W 是权重矩阵,假设其维度为 (dinput,d);那么由于 →x(i) 维度为 (1,dinput),W 的维度为 (dinput,d),可以得出 →z(i) 的维度为 (1,d)。(由于本文中都是行向量,所以上面公式(1)中写作 →x(i)W 更合适)
-
→z(i) 表示全连接层的输出,同时也是softmax层的输入,按照上面的假设,其维度为 (1,d),即 →z(i)={z(i)1,z(i)2,...,z(i)d};
-
→ˆy(i) 表示softmax层的输出,在一般的分类模型中也是最后一层的输出,使用该值和 →y(i) 比较计算损失;softmax运算并不会改变向量的维度,所以 →ˆy(i) 的维度也是 d(从这可以看出这个维度 d 就是该多分类任务的类别个数),有 →ˆy(i)={ˆy(i)1,ˆy(i)2,...,ˆy(i)d};
-
右下角的角标表示第几个维度的元素,所以仅向量/矩阵中的某个元素才有右下角的角标,向量/矩阵是肯定不会有右下角的角标的;
2、损失函数中的符号定义#
损失函数的公式,如下所示:
上述公式中的公式(3)没有任何问题,只是对每条样本的损失求均值,N 表示样本的总数,后面的分析和求导将只对公式(2)进行。
这里的公式(2)和常见的多分类任务的损失函数形式上不太一致,常见的损失函数一般都是如下公式(4)的形式,下面专门说明一下这两种形式的损失函数:
容易知道对于多分类问题,其每条样本的标签对应的向量中,所有的元素里面有且仅有一个元素为1,其他元素都是0,即向量 →y(i)={y(i)1,y(i)2,...,y(i)d} 中仅有一个元素为1,其他元素都为0。为了不失一般性,假设第 j 个元素为1,则有:
将 →ˆy(i) 和 →y(i) 的各个元素代入到公式(2)中:
上述公式(6)中第2行推导出第3行的原因是只有 y(i)j 为1,其他的像 y(i)1、y(i)2、y(i)d 等都为0;公式(6)最后的结果就是公式(4)的形式,只不过它把右下角的下标省略了。需要注意的是:公式(4)中的 ˆy(i) 是一个标量,它是公式(1)中的向量 →ˆy(i) 中的一个元素,这是一个容易混淆的地方。
至此,所有的符号说明完成,后面在求导时使用的损失函数是公式(2)的形式。
3、多分类交叉熵损失求导#
3.1 说明 ∂L(i)∂w 与 ∂L(i)∂z(i) 的区别#
根据链式求导法则,∂L(i)∂w 可以分为三部分,如下式所示:
∂L(i)∂→z(i) 可以分为两部分,如下式所示:
它们之间相差了一个 ∂→z(i)∂w,由前向传播的公式可知,这个导数取决于softmax层之前的网络结构是什么样的。在本文的第一部分假设是一个全连接层,所以其导数为 ∂→z(i)∂w=→x(i);但其也可以是CNN层、RNN层、Transformer层等等,所以对于不同的网络 ∂→z(i)∂w 有着不同的求解方式,其并不属于多分类交叉熵损失部分的求导。所以本文后面在求导时求解的是 ∂L(i)∂→z(i)。
3.2 总体分析#
先总体分析一下:
-
∂L(i)∂→z(i) 是标量对向量求导(标量对向量求导:标量分别对向量中的每个元素求导,最终结果是一个向量,维度与原向量相同),结果是一个向量,维度为 (1,d);
-
∂L(i)∂→ˆy(i) 是标量对向量求导,结果是一个向量,维度为 (1,d);
-
∂→ˆy(i)∂→z(i) 是向量对向量求导,结果是Jacobi矩阵,维度为 (d,d);
-
∂L(i)∂→ˆy(i) 与 ∂→ˆy(i)∂→z(i) 相乘,是维度为 (1,d) 的向量与维度为 (d,d) 的矩阵相乘,最终结果是一个维度为 (1,d) 的向量;
总体来看:在链式求导公式(8)中的各个维度是能够对应上的,没有问题。
3.3 求导#
先求 ∂L(i)∂→ˆy(i),求解结果如下:
再求 ∂→ˆy(i)∂→z(i),这个在文章 Softmax函数求导 中已经求解过了,直接使用其结论:
然后将 ∂L(i)∂→ˆy(i) 与 ∂→ˆy(i)∂→z(i) 相乘就得到了最终结果;可以看到 ∂L(i)∂→ˆy(i) 中只有第 j 个元素非0,所以在 ∂→ˆy(i)∂→z(i) 中只有第 j 行会被使用到,其他行都与0相乘消去了:
求解过程很复杂,求解的结果却很优雅。因为softmax函数和损失本身就是精心设计过的,才能在计算梯度时非常简单、高效。
4、总结#
本文主要是对softmax交叉熵损失做求导,用于进一步理解在多分类任务中反向传播的细节过程。