将“softmax+交叉熵”推广到多标签分类问题#
单标签分类loss#
多分类任务的softmax loss,为了和后面的多标签分类在叫法上保持一致性,也称为"单标签分类"的loss:
对上述公式的最后一行进行分析,其中 s_i 表示非目标类别得分,s_t 表示目标类别得分,s_i 和 s_t 的取值范围理论上是没有限制的,正负都有可能。
由于取对数、求和、求e的幂都是都是单调增的,所以极小化上述公式的最后一行就等同于极小化 s_i-s_t。极小化 s_i-s_t 也就是说希望 s_i < s_t。由此可以得出如下结论:
该loss的目标是:对于每条数据,每个目标类别的得分都大于每个非目标类的得分;
多标签分类loss#
将单标签分类loss的目标扩展到多标签分类:对于每条数据,每个目标类别的得分都大于每个非目标类的得分;
对应的loss公式为:
其中\Omega_{neg}、\Omega_{pos}分别是正负样本的类别集合;
这个loss很容易理解,并且直接使用这个loss就可以进行训练并且可以收敛;
但是这个loss只能用于训练,不能用于推断,在推断时不知道每条数据有多少个类别是pos,有多少个类别是neg;
用于多标签分类#
增加一个新类别 s_0 作为阈值:
将新类别s_0的输出得分恒定为0,目的有两个:
- 一是可以将公式(4)化简;
- 另一个更重要的目的是这样就可以进行推断了。推断时,所有得分大于0的为pos,得分小于0的为neg;
当然这里将s_0的输出得分恒定为其他值,比如1也没有问题,但是代码实现起来麻烦,另外推理时的得分要跟1做比较,得分小于1的表示neg,得分大于1的表示pos。
对公式(3)不做任何化简,仅将s_0设置为0,就已经能够训练和推理了,可以思考一下这里如何实现训练和推理的细节。
将 s_0=0 代入公式(3)并进行化简:
在代码中实现时就按照公式 \log \big ( 1 + \sum_{i \in \Omega_{neg}} e^{s_i} \big ) + \log \big ( 1 + \sum_{j \in \Omega_{pos}} e^{-s_j} \big ) 进行实现。
Pytorch代码#
def multi_label_categorical_cross_entropy(y_true, y_pred):
y_pred = (1 - 2 * y_true) * y_pred
# 计算公式中的 s_i
y_pred_neg = y_pred - y_true * 1e12
# 计算公式中的 -s_j
y_pred_pos = y_pred - (1 - y_true) * 1e12
# 补上得分恒为0的 s_0 类别
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
# 计算 logsumexp
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return torch.mean(neg_loss + pos_loss)
使用该loss时遇到的问题#
任务背景#
任务总共有10个类别:address, book, company, game, government, movie, name, organization, position, scene;
训练集中有5000条数据标注了其中的五个类别:address, book, company, game, government;
训练集中另外5000条数据标注了另外的五个类别:movie, name, organization, position, scene;
问题说明#
计算loss时需要增加一个mask把没有标注的类别去掉;
常规的mask写法如下:
pred = pred * label_mask
在使用该loss时,以上用法有问题;正确的用法是:
pred = pred - (1 - label_mask) * 1e12
原因是对于logsumexp来说,输入0并不是不产生loss,输入负无穷才是不产生loss,如下:
from scipy.special import logsumexp
print(logsumexp([0.1, 0.2, 0.3]) == logsumexp([0.1, 0.2, 0.3, -1e12]))
print(logsumexp([0.1, 0.2, 0.3]) == logsumexp([0.1, 0.2, 0.3, 0]))
输出:
True
False
参考#
苏剑林. (Apr. 25, 2020). 《将“softmax+交叉熵”推广到多标签分类问题 》[Blog post]. Retrieved from https://kexue.fm/archives/7359