[隐藏左侧目录栏][显示左侧目录栏]
中间激活值显存分析
中间激活值(intermediate activations):是在前向传播的过程中,为了让后向传播完成计算,所需要保留的模型中间结果。
1、什么是中间激活值?
以一个四个 Linear 的模型结构为例进行说明。其前向传播和损失函数的公式如下所示:
\begin{split}
x_1 &= W_1 x + b_1 \\
x_2 &= W_2 x_1 + b_2 \\
x_3 &= W_3 x_2 + b_3 \\
x_4 &= W_4 x_3 + b_4 \\
l &= (y - x_4)^2
\end{split}
在该公式中:x 和 y 为数据的特征和标签;W_1、b_1、W_2、b_2、W_3、b_3、W_4、b_4 为四个 Linear 层的权重和偏置;x_1、x_2、x_3、x_4 都是计算过程中的中间状态。
反向传播过程中要对权重进行更新,也就是求损失相对于 W_1、W_2、W_3、W_4 的偏导,按照链式求导法则得到公式如下:
\begin{split}
\frac{\partial l}{\partial W_4} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial W_4} = \Bigg[ -2(y-x_4)\Bigg] \cdot x_3 \\
\frac{\partial l}{\partial W_3} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial W_3} = \Bigg[ [-2(y-x_4)] \cdot W_4 \Bigg] \cdot x_2 \\
\frac{\partial l}{\partial W_2} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial x_2} \cdot \frac{\partial x_2}{\partial W_2} = \Bigg[ [-2(y-x_4)] \cdot W_4 \cdot W_3 \Bigg] \cdot x_1 \\
\frac{\partial l}{\partial W_1} &= \frac{\partial l}{\partial x_4} \cdot \frac{\partial x_4}{\partial x_3} \cdot \frac{\partial x_3}{\partial x_2} \cdot \frac{\partial x_2}{\partial x_1} \cdot \frac{\partial x_1}{\partial W_1} = \Bigg[ [-2(y-x_4)] \cdot W_4 \cdot W_3 \cdot W_2 \Bigg] \cdot x \\
\end{split}
对上面这四个权重矩阵的链式求导公式找一下规律,可以发现对于权重矩阵 W_i 的梯度在计算时主要有两项:
那么对于 W_i 的梯度计算公式就变为了 \frac{\partial l}{\partial W_i} = l_{i+1} \cdot x_{i-1},这里的 l_{i+1} 是第 i+1 层反传过来的,所以计算第 i 层的梯度时只需要做一次矩阵乘法即可。这里的 x_{i-1} 正是在前向传播时计算出来的中间状态,比较官方的术语为 中间激活值。
2、中间激活值显存分析
这里把 transformer 层分为两部分,一部分是 MHA 层,一部分是 FFN 层。下面分别写一下这两部分的公式。一般的资料中关于 transformer 的公式仅写主要的部分,像dropout、normalize、激活函数都会被省略,但是这里由于需要分析中间激活值的显存,所以会把整个 transformer 的所有操作都体现到公式中,如下。
MHA 层的公式如下:
\begin{equation}\begin{split}
Q &= x \cdot W_Q, \quad K = x \cdot W_k, \quad V = x \cdot W_v \\
x_{\text{self}} &= \text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \cdot V \\
x_{\text{attn}} &= \text{LN}\Big[ \text{Dropout}\big(x_{\text{self}} \cdot w_o \big) + x \Big]
\end{split}\end{equation}
FFN 层的公式如下:
\begin{equation}\begin{split}
x_{\text{ffn}} &= \text{GeLU}(x_{\text{attn}} \cdot W_{\text{ff1}}) \cdot W_{\text{ff2}} \\
x_o &= \text{LN}\Big[\text{Dropout}\big(x_{\text{ffn}} \big) + x_{\text{attn}} \Big]
\end{split}\end{equation}
总的来说,MHA 层的输入为 x,输出为 x_{\text{attn}};FFN 层的输入为 x_{\text{attn}},输出为 x_o;
首先面定义几个符号:
- b:表示batch_size;
- s:表示seq_length,为文本长度;
- h:表示hidden_dim,为隐藏层的维度;
- a:表示多头注意力中有多个头;
- h_a:表示hidden_dim_per_head,为多头注意力中每个头的隐藏层维度;
另外,在实际使用时一般都有 h_a * a = h 成立。
MHA 层需要保存的激活值,以及每个激活值的大小:
\begin{alignat}{10}
Q = x \cdot W_Q \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
K = x \cdot W_k \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
V = x \cdot W_v \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
Q \cdot K^T \quad &: \quad \text{维度为 } [b, a, s, s], &\text{大小为 } 2bas^2 \text{ 字节} \\
\text{softmax}(\frac{Q^T K}{\sqrt{d}}) \quad &: \quad \text{维度为 } [b, a, s, s], &\text{大小为 } 2bas^2 \text{ 字节} \\
\text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \quad &: \quad \text{维度为 } [b, a, s, s], &\text{Dropout 层大小为 } bas^2 \text{ 字节} \\
x_{\text{self}} = \text{Dropout}\Big[ \text{softmax}\big(\frac{Q \cdot K^T}{\sqrt{d}} \big) \Big] \cdot V \quad &: \quad \text{维度为 } [b, a, s, h_a] = [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
x_{\text{self}} \cdot W_o \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
\text{Dropout}\big(x_{\text{self} \cdot w_o} \big) \quad &: \quad \text{维度为 } [b, s, h], &\text{Dropout 层大小为 } bsh \text{ 字节} \\
x_{\text{attn}} = \text{LN}\Big[ \text{Dropout}\big(x_{\text{self}} \cdot w_o \big) + x \Big] \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节}
\end{alignat}
FFN 层需要保存的激活值,以及每个激活值的大小:
\begin{alignat}{2}
x_{\text{attn}} \cdot W_{\text{ff1}} \quad &: \quad \text{维度为 } [b, s, 4h], &\text{大小为 } 8bsh \text{ 字节} \\
\text{GeLU} (x_{\text{attn}} \cdot W_{\text{ff1}}) \quad &: \quad \text{维度为 } [b, s, 4h], &\text{大小为 } 8bsh \text{ 字节} \\
x_{\text{ffn}} = \text{GeLU}(x_{\text{attn}} \cdot W_{\text{ff1}}) \cdot W_{\text{ff2}} \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
\text{Dropout}\big(x_{\text{ffn}} \big) \quad &: \quad \text{维度为 } [b, s, h], \qquad &\text{Dropout 层大小为 } bsh \text{ 字节} \\
\text{LN}\Big[\text{Dropout}\big(x_{\text{ffn}} \big) + x_{\text{attn}} \Big] \quad &: \quad \text{维度为 } [b, s, h], &\text{大小为 } 2bsh \text{ 字节} \\
\end{alignat}
将 MHA 和 FFN 层全部加起来得到:
\begin{split}
& 2bsh+2bsh+2bsh+2bas^2+2bas^2+bas^2+2bsh+2bsh+bsh+2bsh+ \\
& 8bsh+8bsh+2bsh+bsh+2bsh=34bsh+5bas^2
\end{split}
如果有 l 层 transformer,那么这 l 层 transformer 总的中间激活值占用的显存为:l * (34bsh+5bas^2)
2.3 embedding 层和解码层
上面仅分析了多个 transformer 对应的中间激活值消耗的显存的大小。模型中还会有 embedding 层和解码层。其中解码层没有对应的中间激活值,只需要分析一下 embedding 层即可。
embedding 层的功能是将输入的 token ID 转为向量,其输出的矩阵维度为 [batch_size, seq_length, hidden_size],即 [b, s, h],该中间激活值占用的显存为 2bsh。
综上所述,整个模型所有的中间激活值的大小为 l * (34bsh + 5bas^2) + 2bsh。随着模型越来越大,l 是比较大的,所以有时会忽略 2bsh 这一项,直接使用 l*(34bsh+5bas^2) 来估计模型的中间激活值的大小。
3、实例分析之bert-base
以 bert-base 为例分析其中间激活值所需要的显存的大小。在文章 静态显存分析 中已经得出对于 bert-base 来说,模型参数、梯度、优化器状态三部分总共需要的显存为 1.76G,下面分析一下该模型的中间激活值所需要的显存。
计算的公式为 l * (34bsh + 5bas^2) + 2bsh,先来看这里的每个值分别是多少:
-
l 为 tranformer 的层数,bert-base 为 12 层;
-
b 为 batch_size,这里分别取 batch_size 为 16 和 1 计算两个不同 batch_size 下所需要显存的差别;
-
s 为 seq_length,模型 bert-base 的最大长度为 512;
-
h 为隐藏层的维度,为 768;
-
a 为多头注意力层的 head 个数,为 12;
当 batch_size 为 1 时,计算结果如下:
\begin{equation}\begin{split}
&12 * (34 * 1 * 512 * 768 + 5 * 1 * 12 * 512 * 512) + 2 * 1 * 512 * 768 \\
= &12 * (13.4M + 15.7M) + 0.79M \\
= &12 * 29.1M + 0.79M \\
= &250M
\end{split}\end{equation}
当 batch_size 为 16 时,计算结果如下:
\begin{equation}\begin{split}
&12 * (34 * 16 * 512 * 768 + 5 * 16 * 12 * 512 * 512) + 2 * 16 * 512 * 768 \\
= &12 * (213M + 251M) + 12M \\
= &12 * 464M + 12M \\
= &5.58G
\end{split}\end{equation}
从公式 l * (34bsh + 5bas^2) + 2bsh 中就可以看出,中间激活值消耗的显存的大小与 batch_size 是正相关的,在 bert-base 这个模型中,batch_size 每增加 1,那么中间激活值所需要的显存就会增加大约 250M 左右。
可以看到当 batch_size 为 16 时,中间激活值所需要的显存为 5.58G,而静态显存只有 1.76G,中间激活值比这部分静态显存还要大。
在实际训练时,遇到显存不足时,调小 batch_size 实际减少的就是中间激活值这部分所需要的显存。而模型参数、梯度、优化器这部分静态显存与 batch_size 是无关的,调小 batch_size 不会影响这里所需要的显存。
4、实例分析之LLAMA-65B
以 LLAMA-65B 为例分析其中间激活值所需要的显存的大小。在文章 静态显存分析 中已经得出对于 LLAMA-65B 来说,模型参数、梯度、优化器状态三部分总共需要的显存为 1040G,下面分析一下该模型的中间激活值所需要的显存。
计算的公式为 l * (34bsh + 5bas^2) + 2bsh,先来看这里的每个值分别是多少:
-
l 为 tranformer 的层数,LLAMA-65B 为 80 层;
-
b 为 batch_size,这里分别取 batch_size 为 16 和 1 计算两个不同 batch_size 下所需要显存的差别;
-
s 为 seq_length,模型 LLAMA-65B 的最大长度为 2048;
-
h 为隐藏层的维度,为 8192;
-
a 为多头注意力层的 head 个数,为 64;
当 batch_size 为 1 时,计算结果如下:
\begin{equation}\begin{split}
&80 * (34 * 1 * 2048 * 8192 + 5 * 1 * 64 * 2048 * 2048) + 2 * 1 * 2048 * 8192 \\
= &80 * (570M + 1342M) + 33M \\
= &80 * 1912M + 33M \\
= &153G
\end{split}\end{equation}
当 batch_size 为 16 时,计算结果如下:
\begin{equation}\begin{split}
&80 * (34 * 16 * 2048 * 8192 + 5 * 16 * 64 * 2048 * 2048) + 2 * 16 * 2048 * 8192 \\
= &80 * (9.13G + 21.47G) + 536M \\
= &80 * 30.6G + 536M \\
= &2448G
\end{split}\end{equation}
从公式 l * (34bsh + 5bas^2) + 2bsh 中就可以看出,中间激活值消耗的显存的大小与 batch_size 是正相关的,在 LLAMA-65B 这个模型中,batch_size 每增加 1,那么中间激活值所需要的显存就会增加大约 153G 左右。
可以看到当 batch_size 为 16 时,中间激活值所需要的显存为 2448G,而静态显存只有 1040G,中间激活值比这部分静态显存还要大。
对于大模型,在实际训练时,现在基本都会使用梯度检查技术,这样中间激活值这里需要的显存就会大幅减少,所以实际训练时并不会真实消耗这么多的显存。
总结
本文介绍了在训练阶段中,哪一部分是中间激活值;以及中间激活值的大小如何估算;最后以模型 bert-base 和 LLAMA-65B 为例计算了其训练时中间激活值所需的显存。
Reference