[隐藏左侧目录栏][显示左侧目录栏]

Pre Norm 与 Post Norm#

1、公式#

Post Norm 的公式:

xt+1=Norm(xt+Ft(xt))

Pre Norm 的公式:

xt+1=xt+Ft(Norm(xt))

2、实验中观察到的现象#

这是一个总共12层,有125M参数的模型。下图展示的是在模型训练的早期(可以看出step都是小于600的),第0层、第1层、第6层、第10层、第11层的梯度的平均L1范数值。

可以看出,使用 Post Norm 时,第0层和第1层的梯度较小,第10层和第11层的梯度较大。也就是靠前的层梯度较小,靠后的层梯度较大。并且有其他研究表明,靠前层的梯度还会随着训练的进行还会迅速减小,导致靠前的层之后就得不到充分的训练。

使用 Pre Norm 时,第0层和第1层的梯度较大,第10层和第11层的梯度较小。也就是靠前的层梯度较大,靠后的层梯度较小。

就理想情况来说,无论是"靠前层梯度小靠后层梯度大",还是"靠前层梯度大靠后层梯度小",都不是想要的,最想要的是:每一层的梯度是相同的(这里不是说梯度完全相同,而是梯度的统计值相同,比如每层梯度的L1范数值相同)。做这方面的工作有很多,比如 NormFormer 是在 Pre Norm 的基础上做优化,通过缓解 "靠前层梯度大靠后层梯度小" 的问题达到每层梯度相同;DeepNet 是在 Post Norm 的基础上做优化,通过缓解 "靠前层梯度小靠后层梯度大" 的问题达到每层梯度相同;

上述现象出自论文: NormFormer: Improved Transformer Pretraining with Extra Normalization

上面是从实验中观察到的现象,下面从理论方面分析一下产生该现象的原因。

3、残差连接与 Normalize 的联系#

残差连接这种设计的思路是给比较靠前面的层一条绿色通道,反向传播时梯度可以沿着这条绿色通道直接传递到比较靠前面的层,其前向传播的公式为:x+F(x)

记随机变量 x 的方差为 σ21,记随机变量 F(x) 的方差为 σ22,并且假设这两个随机变量是相互独立的,那么有随机变量 x+F(x) 的方差为 σ21+σ22

这里利用了方差的性质:如果 xy 是独立的随机变量,那么有 Var(x+y)=Var(x)+Var(y)

更多方差的性质见: https://zhuanlan.zhihu.com/p/161505873

可以看出,由于方差本身必然是非负的,在残差连接的结构设计下,随着层数的加深,方差是越来越大的,所以需要有一个策略控制方差在一定的范围内。而 normalize 就是能够控制方差很好的一种方法,下面分别分析 Pre Norm 和 Post Norm 在控制方差方面的效果,以及各自自身的缺陷。

4、Post Norm#

Post Norm 的公式为 xt+1=Norm(xt+Ft(xt))

在对其分析之前,假设初始状态下 xF(x) 的方差都是1,所以此时 x+F(x) 的方差为2。经过 Post Norm 之后 x+F(x) 的方差又变为了1。

也就是说对 x+F(x) 做 Norm 等同于对 x+F(x) 除以 2,如下公式所示:

xt+1=Norm(xt+Ft(xt))xt+Ft(xt)2

基于上述公式,一直递归下去,公式如下所示:

xt+1=xt+Ft(xt)2=xt21/2+Ft(xt)21/2=xt122/2+Ft1(xt1)22/2+Ft(xt)21/2=xt223/2+Ft2(xt2)23/2+Ft1(xt1)22/2+Ft(xt)21/2=......=x12t/2+F1(x1)2t/2+F2(x2)2(t1)/2+......+Ft1(xt1)22/2+Ft(xt)21/2

上述推导出的公式的最后一行的含义为:模型的第 t+1 层的输出结果是由模型第1层的原始输入x1、模型第1层的输出结果F1(x1)、模型第2层的输出结果F2(x2)直到模型第t层的输出结果Ft(xt)加权求和得到的。

主要就是看这个加权中的权限值,第1层的原始输入、以及从第1层到第t层的输出结果的权重为 {12t/2,12t/2,12(t1)/2,......,122/2,121/2}。可以明显的看出越靠前的层,其衰减的越严重。

综合来说:

  • 无论模型的哪一层,经过 Post Norm 之后其方差都被严格的控制到一个固定的范围;
  • Post Norm 会导致越靠前的层衰减的越严重,这和残差连接设计的初衷是相悖的。另外在实际使用中,使用了 Post Norm 之后模型训练起来也比较困难,必须要使用 warmup 等机制来保证模型收敛。

5、Pre Norm#

Pre Norm 的公式为 xt+1=xt+Ft(Norm(xt))

基于该公式,一直递归下去,得到的公式如下所示:

xt+1=xt+Ft(Norm(xt))=xt1+Ft1(Norm(xt1))+Ft(Norm(xt))=xt2+Ft2(Norm(xt2))+Ft1(Norm(xt1))+Ft(Norm(xt))=......=x1+F1(Norm(x1))+F2(Norm(x2))+F3(Norm(x3))+......+Ft1(Norm(xt1))+Ft(Norm(xt))

由上述公式推导的最后一行可以看出,第 t+1 层模型的输出是由:第1层的原属输x1、以及各层模型的输出直接求和得到的。相比于 Post Norm 对越靠前的层衰减越严重,这里的 Pre Norm 对待所有层的输出一视同仁,直接求和。

另外,容易看出随着层数的加深,整体的方差是在不断的增大的,所以使用 Pre Norm 时一般在最后一层之后再加个总的 Norm 层。

相比于 Post Norm,Pre Norm 能够真正的给靠前的层一个绿色通道,直达模型最终层。这样在反向传播时就不存在靠前的层梯度非常小的问题,所以 Pre Norm 训练起来要比 Post Norm 要容易训练。

每一层 F1, F2, ..., Ft1, Ft 的输入都是经过 Norm 的,所以其输入都可看作方差为1的随机变量,而每一层的模型结构又是相同的,所以从统计上来看每一层的输出结果方差也是相同的。基于此,有一种观点认为当层数 t 比较大时,单层的输出对总输出的贡献是小量,此时有下述公式成立:

xt+2=xt+1+Ft+1(Norm(xt+1))=xt+Ft(Norm(xt))+Ft+1(Norm(xt+1))xt+2Ft(Norm(xt))

这样一个 t+2 层的网络就变成了一个更宽一些的 t+1 层的网络。由于现在整个深度学习都是使用反向传播,所以在相同的参数量下,宽而浅的网络训练起来是要比深而窄的网络要更好训练的,所以使用 Pre Norm 的网络比使用 Post Norm 的网络更好训练一些。

6、两者对比#

其实在上面的两个小节中,所以的内容都已经分析过了,这个小节只是做一下汇总。

Pre Norm Post Norm
好训练 不好训练,基本必须要和warmup一起使用
对原始输入x和各层网络的输出一视同仁,平权相加 非平权,越靠前的层衰减越严重
在对模型更新时,靠前层梯度大,靠后层的梯度小 在对模型更新时,靠前层梯度小,靠后层的梯度大

Reference#