[隐藏左侧目录栏][显示左侧目录栏]

Pre Norm 与 Post Norm#

1、公式#

Post Norm 的公式:

\begin{equation}x_{t+1} = \text{Norm}(x_t + F_t(x_t))\end{equation}

Pre Norm 的公式:

\begin{equation}x_{t+1} = x_t + F_t(\text{Norm}(x_t))\end{equation}

2、实验中观察到的现象#

这是一个总共12层,有125M参数的模型。下图展示的是在模型训练的早期(可以看出step都是小于600的),第0层、第1层、第6层、第10层、第11层的梯度的平均L1范数值。

可以看出,使用 Post Norm 时,第0层和第1层的梯度较小,第10层和第11层的梯度较大。也就是靠前的层梯度较小,靠后的层梯度较大。并且有其他研究表明,靠前层的梯度还会随着训练的进行还会迅速减小,导致靠前的层之后就得不到充分的训练。

使用 Pre Norm 时,第0层和第1层的梯度较大,第10层和第11层的梯度较小。也就是靠前的层梯度较大,靠后的层梯度较小。

就理想情况来说,无论是"靠前层梯度小靠后层梯度大",还是"靠前层梯度大靠后层梯度小",都不是想要的,最想要的是:每一层的梯度是相同的(这里不是说梯度完全相同,而是梯度的统计值相同,比如每层梯度的L1范数值相同)。做这方面的工作有很多,比如 NormFormer 是在 Pre Norm 的基础上做优化,通过缓解 "靠前层梯度大靠后层梯度小" 的问题达到每层梯度相同;DeepNet 是在 Post Norm 的基础上做优化,通过缓解 "靠前层梯度小靠后层梯度大" 的问题达到每层梯度相同;

上述现象出自论文: NormFormer: Improved Transformer Pretraining with Extra Normalization

上面是从实验中观察到的现象,下面从理论方面分析一下产生该现象的原因。

3、残差连接与 Normalize 的联系#

残差连接这种设计的思路是给比较靠前面的层一条绿色通道,反向传播时梯度可以沿着这条绿色通道直接传递到比较靠前面的层,其前向传播的公式为:x + F(x)

记随机变量 x 的方差为 \sigma_1^2,记随机变量 F(x) 的方差为 \sigma_2^2,并且假设这两个随机变量是相互独立的,那么有随机变量 x+F(x) 的方差为 \sigma_1^2+\sigma_2^2

这里利用了方差的性质:如果 xy 是独立的随机变量,那么有 \text{Var}(x+y) = \text{Var}(x) + \text{Var}(y)

更多方差的性质见: https://zhuanlan.zhihu.com/p/161505873

可以看出,由于方差本身必然是非负的,在残差连接的结构设计下,随着层数的加深,方差是越来越大的,所以需要有一个策略控制方差在一定的范围内。而 normalize 就是能够控制方差很好的一种方法,下面分别分析 Pre Norm 和 Post Norm 在控制方差方面的效果,以及各自自身的缺陷。

4、Post Norm#

Post Norm 的公式为 x_{t+1} = \text{Norm}(x_t + F_t(x_t))

在对其分析之前,假设初始状态下 xF(x) 的方差都是1,所以此时 x+F(x) 的方差为2。经过 Post Norm 之后 x+F(x) 的方差又变为了1。

也就是说对 x+F(x) 做 Norm 等同于对 x+F(x) 除以 \sqrt{2},如下公式所示:

\begin{equation}x_{t+1} = \text{Norm}(x_t + F_t(x_t)) \rightarrow \frac{x_t+F_t(x_t)}{\sqrt{2}}\end{equation}

基于上述公式,一直递归下去,公式如下所示:

\begin{equation}\begin{split} x_{t+1} &= \frac{x_t + F_t(x_t)}{\sqrt{2}} \\ &= \frac{x_t}{2^{1/2}} + \frac{F_t(x_t)}{2^{1/2}} \\ & = \frac{x_{t-1}}{2^{2/2}} + \frac{F_{t-1}(x_{t-1})}{2^{2/2}} + \frac{F_t(x_t)}{2^{1/2}} \\ &= \frac{x_{t-2}}{2^{3/2}} + \frac{F_{t-2}(x_{t-2})}{2^{3/2}} + \frac{F_{t-1}(x_{t-1})}{2^{2/2}} + \frac{F_t(x_t)}{2^{1/2}} \\ &= ... ... \\ &= \frac{x_1}{2^{t/2}} + \frac{F_1(x_1)}{2^{t/2}} + \frac{F_2(x_2)}{2^{(t-1)/2}} + ... ... + \frac{F_{t-1}(x_{t-1})}{2^{2/2}} + \frac{F_t(x_t)}{2^{1/2}} \end{split}\end{equation}

上述推导出的公式的最后一行的含义为:模型的第 t+1 层的输出结果是由模型第1层的原始输入x_1、模型第1层的输出结果F_1(x_1)、模型第2层的输出结果F_2(x_2)直到模型第t层的输出结果F_t(x_t)加权求和得到的。

主要就是看这个加权中的权限值,第1层的原始输入、以及从第1层到第t层的输出结果的权重为 \Big\{\frac{1}{2^{t/2}}, \frac{1}{2^{t/2}}, \frac{1}{2^{(t-1)/2}}, ... ..., \frac{1}{2^{2/2}}, \frac{1}{2^{1/2}}\Big\}。可以明显的看出越靠前的层,其衰减的越严重。

综合来说:

  • 无论模型的哪一层,经过 Post Norm 之后其方差都被严格的控制到一个固定的范围;
  • Post Norm 会导致越靠前的层衰减的越严重,这和残差连接设计的初衷是相悖的。另外在实际使用中,使用了 Post Norm 之后模型训练起来也比较困难,必须要使用 warmup 等机制来保证模型收敛。

5、Pre Norm#

Pre Norm 的公式为 x_{t+1} = x_t + F_t(\text{Norm}(x_t))

基于该公式,一直递归下去,得到的公式如下所示:

\begin{equation}\begin{split} x_{t+1} &= x_t + F_t(\text{Norm}(x_t)) \\ &= x_{t-1} + F_{t-1}(\text{Norm}(x_{t-1})) + F_t(\text{Norm}(x_t)) \\ &= x_{t-2} + F_{t-2}(\text{Norm}(x_{t-2})) + F_{t-1}(\text{Norm}(x_{t-1})) + F_t(\text{Norm}(x_t)) \\ &= ... ... \\ &= x_1 + F_1(\text{Norm}(x_1)) + F_2(\text{Norm}(x_2)) + F_3(\text{Norm}(x_3)) + ... ... + F_{t-1}(\text{Norm}(x_{t-1})) + F_t(\text{Norm}(x_t)) \end{split}\end{equation}

由上述公式推导的最后一行可以看出,第 t+1 层模型的输出是由:第1层的原属输x_1、以及各层模型的输出直接求和得到的。相比于 Post Norm 对越靠前的层衰减越严重,这里的 Pre Norm 对待所有层的输出一视同仁,直接求和。

另外,容易看出随着层数的加深,整体的方差是在不断的增大的,所以使用 Pre Norm 时一般在最后一层之后再加个总的 Norm 层。

相比于 Post Norm,Pre Norm 能够真正的给靠前的层一个绿色通道,直达模型最终层。这样在反向传播时就不存在靠前的层梯度非常小的问题,所以 Pre Norm 训练起来要比 Post Norm 要容易训练。

每一层 F_1, F_2, ..., F_{t-1}, F_t 的输入都是经过 Norm 的,所以其输入都可看作方差为1的随机变量,而每一层的模型结构又是相同的,所以从统计上来看每一层的输出结果方差也是相同的。基于此,有一种观点认为当层数 t 比较大时,单层的输出对总输出的贡献是小量,此时有下述公式成立:

\begin{equation}\begin{split} x_{t+2} &= x_{t+1} + F_{t+1}(\text{Norm}(x_{t+1})) \\ &= x_t + F_t(\text{Norm}(x_t)) + F_{t+1}(\text{Norm}(x_{t+1})) \\ & \approx x_t + 2 \cdot F_t(\text{Norm}(x_t)) \end{split}\end{equation}

这样一个 t+2 层的网络就变成了一个更宽一些的 t+1 层的网络。由于现在整个深度学习都是使用反向传播,所以在相同的参数量下,宽而浅的网络训练起来是要比深而窄的网络要更好训练的,所以使用 Pre Norm 的网络比使用 Post Norm 的网络更好训练一些。

6、两者对比#

其实在上面的两个小节中,所以的内容都已经分析过了,这个小节只是做一下汇总。

Pre Norm Post Norm
好训练 不好训练,基本必须要和warmup一起使用
对原始输入x和各层网络的输出一视同仁,平权相加 非平权,越靠前的层衰减越严重
在对模型更新时,靠前层梯度大,靠后层的梯度小 在对模型更新时,靠前层梯度小,靠后层的梯度大

Reference#