[隐藏左侧目录栏][显示左侧目录栏]

中间激活值显存分析#

中间激活值(intermediate activations):是在前向传播的过程中,为了让后向传播完成计算,所需要保留的模型中间结果。

1、什么是中间激活值?#

以一个四个 Linear 的模型结构为例进行说明。其前向传播和损失函数的公式如下所示:

\begin{split} x_1 &= W_1 x + b_1 \\ x_2 &= W_2 x_1 + b_2 \\ x_3 &= W_3 x_2 + b_3 \\ x_4 &= W_4 x_3 + b_4 \\ l &= (y - x_4)^2 \end{split}

在该公式中:xy 为数据的特征和标签;W_1b_1W_2b_2W_3b_3W_4b_4 为四个 Linear 层的权重和偏置;x_1x_2x_3x_4 都是计算过程中的中间状态。

反向传播过程中要对权重进行更新,也就是求损失相对于 W_1W_2W_3W_4 的偏导,按照链式求导法则得到公式如下:

\begin{split} \frac{\partial l}{\partial W_4} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial W_4} = \Bigg[ -2(y-x_4)\Bigg] \cdot x_3 \\ \frac{\partial l}{\partial W_3} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial W_3} = \Bigg[ [-2(y-x_4)] \cdot W_4 \Bigg] \cdot x_2 \\ \frac{\partial l}{\partial W_2} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial x_2} \cdot \frac{\partial x_2}{\partial W_2} = \Bigg[ [-2(y-x_4)] \cdot W_4 \cdot W_3 \Bigg] \cdot x_1 \\ \frac{\partial l}{\partial W_1} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial x_2} \cdot \frac{\partial x_2}{\partial x_1} \cdot \frac{\partial x_1}{\partial W_1} = \Bigg[ [-2(y-x_4)] \cdot W_4 \cdot W_3 \cdot W_2 \Bigg] \cdot x \\ \end{split}

对上面这四个权重矩阵的链式求导公式找一下规律,可以发现对于权重矩阵 W_i 的梯度在计算时主要有两项:

  • 第一项是上述公式中使用特别大的中括号扩起来的部分,这部分是第 i+1 层反传回来的值,我们使用符号 l_{i+1} 来表示这一项;

  • 另一项则是第 i-1 层计算出来的中间值,使用符号 x_{i-1} 来表示;

那么对于 W_i 的梯度计算公式就变为了 \frac{\partial l}{\partial W_i} = l_{i+1} \cdot x_{i-1},这里的 l_{i+1} 是第 i+1 层反传过来的,所以计算第 i 层的梯度时只需要做一次矩阵乘法即可。这里的 x_{i-1} 正是在前向传播时计算出来的中间状态,比较官方的术语为 中间激活值

2、中间激活值显存分析#

2.1 tranformer 的结构#

这里把 transformer 层分为两部分,一部分是 MHA 层,一部分是 FFN 层。下面分别写一下这两部分的公式。一般的资料中关于 transformer 的公式仅写主要的部分,像dropout、normalize、激活函数都会被省略,但是这里由于需要分析中间激活值的显存,所以会把整个 transformer 的所有操作都体现到公式中,如下。

MHA 层的公式如下:

\begin{equation}\begin{split} Q &= x \cdot W_Q, \quad K = x \cdot W_k, \quad V = x \cdot W_v \\ x_{\text{self}} &= \text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \cdot V \\ x_{\text{attn}} &= \text{LN}\Big[ \text{Dropout}\big(x_{\text{self}} \cdot w_o \big) + x \Big] \end{split}\end{equation}

FFN 层的公式如下:

\begin{equation}\begin{split} x_{\text{ffn}} &= \text{GeLU}(x_{\text{attn}} \cdot W_{\text{ff1}}) \cdot W_{\text{ff2}} \\ x_o &= \text{LN}\Big[\text{Dropout}\big(x_{\text{ffn}} \big) + x_{\text{attn}} \Big] \end{split}\end{equation}

总的来说,MHA 层的输入为 x,输出为 x_{\text{attn}};FFN 层的输入为 x_{\text{attn}},输出为 x_o

2.2 tranformer 的中间激活值分析#

首先面定义几个符号:

  • b:表示batch_size;
  • s:表示seq_length,为文本长度;
  • h:表示hidden_dim,为隐藏层的维度;
  • a:表示多头注意力中有多个头;
  • h_a:表示hidden_dim_per_head,为多头注意力中每个头的隐藏层维度;

另外,在实际使用时一般都有 h_a * a = h 成立。

MHA 层需要保存的激活值,以及每个激活值的大小:

\begin{alignat}{10} Q = x \cdot W_Q \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ K = x \cdot W_k \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ V = x \cdot W_v \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ Q \cdot K^T \quad &: \quad \text{维度为 } [b, a, s, s], &\text{大小为 } 2bas^2 \text{ 字节} \\ \text{softmax}(\frac{Q^T K}{\sqrt{d}}) \quad &: \quad \text{维度为 } [b, a, s, s], &\text{大小为 } 2bas^2 \text{ 字节} \\ \text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \quad &: \quad \text{维度为 } [b, a, s, s], &\text{Dropout 层大小为 } bas^2 \text{ 字节} \\ x_{\text{self}} = \text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \cdot V \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ x_{\text{self}} \cdot W_o \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ \text{Dropout}\big(x_{\text{self} \cdot w_o} \big) \quad &: \quad \text{维度为 } [b, s, h], &\text{Dropout 层大小为 } bsh \text{ 字节} \\ x_{\text{attn}} = \text{LN}\Big[ \text{Dropout}\big(x_{\text{self}} \cdot w_o \big) + x \Big] \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \end{alignat}

FFN 层需要保存的激活值,以及每个激活值的大小:

\begin{alignat}{2} x_{\text{attn}} \cdot W_{\text{ff1}} \quad &: \quad \text{维度为 } [b, s, 4h], &\text{大小为 } 8bsh \text{ 字节} \\ \text{GeLU} (x_{\text{attn}} \cdot W_{\text{ff1}}) \quad &: \quad \text{维度为 } [b, s, 4h], &\text{大小为 } 8bsh \text{ 字节} \\ x_{\text{ffn}} = \text{GeLU}(x_{\text{attn}} \cdot W_{\text{ff1}}) \cdot W_{\text{ff2}} \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ \text{Dropout}\big(x_{\text{ffn}} \big) \quad &: \quad \text{维度为 } [b, s, h], \qquad &\text{Dropout 层大小为 } bsh \text{ 字节} \\ \text{LN}\Big[\text{Dropout}\big(x_{\text{ffn}} \big) + x_{\text{attn}} \Big] \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\ \end{alignat}

将 MHA 和 FFN 层全部加起来得到:

\begin{split} & 2bsh+2bsh+2bsh+2bas^2+2bas^2+bas^2+2bsh+2bsh+bsh+2bsh+ \\ & 8bsh+8bsh+2bsh+bsh+2bsh=34bsh+5bas^2 \end{split}

如果有 l 层 transformer,那么这 l 层 transformer 总的中间激活值占用的显存为:l * (34bsh+5bas^2)

2.3 embedding 层和解码层#

上面仅分析了多个 transformer 对应的中间激活值消耗的显存的大小。模型中还会有 embedding 层和解码层。其中解码层没有对应的中间激活值,只需要分析一下 embedding 层即可。

embedding 层的功能是将输入的 token ID 转为向量,其输出的矩阵维度为 [batch_size, seq_length, hidden_size],即 [b, s, h],该中间激活值占用的显存为 2bsh

综上所述,整个模型所有的中间激活值的大小为 l * (34bsh + 5bas^2) + 2bsh。随着模型越来越大,l 是比较大的,所以有时会忽略 2bsh 这一项,直接使用 l*(34bsh+5bas^2) 来估计模型的中间激活值的大小。

3、实例分析之bert-base#

以 bert-base 为例分析其中间激活值所需要的显存的大小。在文章 静态显存分析 中已经得出对于 bert-base 来说,模型参数、梯度、优化器状态三部分总共需要的显存为 1.76G,下面分析一下该模型的中间激活值所需要的显存。

计算的公式为 l * (34bsh + 5bas^2) + 2bsh,先来看这里的每个值分别是多少:

  • l 为 tranformer 的层数,bert-base 为 12 层;

  • b 为 batch_size,这里分别取 batch_size 为 16 和 1 计算两个不同 batch_size 下所需要显存的差别;

  • s 为 seq_length,模型 bert-base 的最大长度为 512;

  • h 为隐藏层的维度,为 768;

  • a 为多头注意力层的 head 个数,为 12;

当 batch_size 为 1 时,计算结果如下:

\begin{equation}\begin{split} &12 * (34 * 1 * 512 * 768 + 5 * 1 * 12 * 512 * 512) + 2 * 1 * 512 * 768 \\ = &12 * (13.4M + 15.7M) + 0.79M \\ = &12 * 29.1M + 0.79M \\ = &250M \end{split}\end{equation}

当 batch_size 为 16 时,计算结果如下:

\begin{equation}\begin{split} &12 * (34 * 16 * 512 * 768 + 5 * 16 * 12 * 512 * 512) + 2 * 16 * 512 * 768 \\ = &12 * (213M + 251M) + 12M \\ = &12 * 464M + 12M \\ = &5.58G \end{split}\end{equation}

从公式 l * (34bsh + 5bas^2) + 2bsh 中就可以看出,中间激活值消耗的显存的大小与 batch_size 是正相关的,在 bert-base 这个模型中,batch_size 每增加 1,那么中间激活值所需要的显存就会增加大约 250M 左右。

可以看到当 batch_size 为 16 时,中间激活值所需要的显存为 5.58G,而静态显存只有 1.76G,中间激活值比这部分静态显存还要大。

在实际训练时,遇到显存不足时,调小 batch_size 实际减少的就是中间激活值这部分所需要的显存。而模型参数、梯度、优化器这部分静态显存与 batch_size 是无关的,调小 batch_size 不会影响这里所需要的显存。

4、实例分析之LLAMA-65B#

以 LLAMA-65B 为例分析其中间激活值所需要的显存的大小。在文章 静态显存分析 中已经得出对于 LLAMA-65B 来说,模型参数、梯度、优化器状态三部分总共需要的显存为 1040G,下面分析一下该模型的中间激活值所需要的显存。

计算的公式为 l * (34bsh + 5bas^2) + 2bsh,先来看这里的每个值分别是多少:

  • l 为 tranformer 的层数,LLAMA-65B 为 80 层;

  • b 为 batch_size,这里分别取 batch_size 为 16 和 1 计算两个不同 batch_size 下所需要显存的差别;

  • s 为 seq_length,模型 LLAMA-65B 的最大长度为 2048;

  • h 为隐藏层的维度,为 8192;

  • a 为多头注意力层的 head 个数,为 64;

当 batch_size 为 1 时,计算结果如下:

\begin{equation}\begin{split} &80 * (34 * 1 * 2048 * 8192 + 5 * 1 * 64 * 2048 * 2048) + 2 * 1 * 2048 * 8192 \\ = &80 * (570M + 1342M) + 33M \\ = &80 * 1912M + 33M \\ = &153G \end{split}\end{equation}

当 batch_size 为 16 时,计算结果如下:

\begin{equation}\begin{split} &80 * (34 * 16 * 2048 * 8192 + 5 * 16 * 64 * 2048 * 2048) + 2 * 16 * 2048 * 8192 \\ = &80 * (9.13G + 21.47G) + 536M \\ = &80 * 30.6G + 536M \\ = &2448G \end{split}\end{equation}

从公式 l * (34bsh + 5bas^2) + 2bsh 中就可以看出,中间激活值消耗的显存的大小与 batch_size 是正相关的,在 LLAMA-65B 这个模型中,batch_size 每增加 1,那么中间激活值所需要的显存就会增加大约 153G 左右。

可以看到当 batch_size 为 16 时,中间激活值所需要的显存为 2448G,而静态显存只有 1040G,中间激活值比这部分静态显存还要大。

对于大模型,在实际训练时,现在基本都会使用梯度检查技术,这样中间激活值这里需要的显存就会大幅减少,所以实际训练时并不会真实消耗这么多的显存。

总结#

本文介绍了在训练阶段中,哪一部分是中间激活值;以及中间激活值的大小如何估算;最后以模型 bert-base 和 LLAMA-65B 为例计算了其训练时中间激活值所需的显存。

Reference#