计算量分析#
为了方便描述,首先定义各符号的含义:
- b:表示batch_size;
- s:表示seq_length,为文本长度;
- h:表示hidden_dim,为隐藏层的维度;
- a:表示多头注意力中有多个头;
- h_a:表示hidden_dim_per_head,为多头注意力中每个头的隐藏层维度;
另外,在实际使用时一般都有 h_a * a = h 成立。
1、矩阵相乘的计算量#
通过一个例子来说明两个矩阵相乘时所需要的计算量。现在有矩阵 A 和矩阵 B,这两个矩阵如下所示:
将这两个矩阵相乘之后得到如下结果:
可以看出矩阵 A 的维度为 [2,3],矩阵 B 的维度为 [3,4],相乘结果的矩阵维度为 [2,4]。结果矩阵中每个元素中有 3 次乘法运算和 (3-1=2) 次加法运算。所以矩阵 A 乘以矩阵 B 总共有 2*4*3 次乘法运算,总共有 2*4*(3-1) 次加法运算。
假设矩阵 A 的维度为 [m,n],假设矩阵 B 的维度为 [n,p],那么两个矩阵相乘 AB 总共有 m \cdot p \cdot n 次乘法运算,总共有 m \cdot p \cdot (n-1) 次加法运算,由于 n 一般远大于1,所以会把减一部分忽略掉。即有如下成立:
对于 A \in R^{m*n},B \in R^{n*p},计算出 AB 所需要的运算次数为 2mnp
这里所提到的每次运算就是浮点运算。
2、transformer结构中计算方式分类#
为了估计计算量,transformer 结构中总共有三种类型的计算:
-
第一种是模型中的 Linear 层,该计算是矩阵乘法运算,即 x \cdot W,此时计算量与 x 和 W 有关。
-
第二种是逐个元素的计算方式,比如softmax、dropout、normalize、残差连接、激活函数等。
-
第三种是注意力权重矩阵部分的计算,即 Q \cdot K^T,该计算也是矩阵乘法运算,但是参与相乘的两个矩阵都是中间激活值,与模型的参数量无关。
下面先逐类型阐述上述每种计算类型的计算量如何计算,然后针对 transformer 结构的计算量做整体估计。
3、Linear的计算量#
假设 Linear 层的输入维度为 [b,s,h],Linear 的权重参数维度为 [h,h],前向传播的公式为:
维度变化为:
所以,对于 Linear 模型结构,一次前向传播的计算量为 2 \cdot b \cdot s \cdot h \cdot h=2bsh^2
在文章 反向传播实现说明 中已经说明过,Linear 的反向传播求解梯度的过程为两次矩阵乘法,所以一次反向传播中的计算量是一次前向传播的计算量的两倍,即:4bsh^2
如果是一次前向传播和一次反向传播完成训练的话,那么 Linear 层的总计算量为 6bsh^2。如果使用了梯度检查技术,那么还需要再增加一次前向传播才能完成训练,那么总计算量为 8bsh^2。
接下来将计算量与 token 数量和模型的参数量做一下关联。
输入矩阵 x 的维度是 [b,s,h],其对应的 token 数量为 b \cdot s;权重矩阵 W 的维度是 [h,h],其对应的参数量为 h \cdot h。直接将 token 的数量与参数量相乘可得 bsh^2,可以看出该结果刚好等于计算量中除去系数的部分。
所以对于 Linear 层部分的网络的计算量可以直接使用如下公式进行估计:
其中 \text{token} 表示参与训练的 token 的数量,\text{param} 表示模型的参数量。
注意:上述公式仅用于估计 Linear 网络结构的计算量,其他网络结构并不适用。
4、逐元素运算符的计算量#
在深度神经网络中有两个矩阵相乘的操作,也有针对矩阵中每个元素进行运算的操作,比如softmax、dropout、normalize、残差连接、激活函数等。这些逐元素的运算不涉及两个矩阵,而是对单个矩阵中的每个元素进行运算。以 sigmoid 为例来看。
比如 sigmoid 公式为:
sigmoid 的导数公式为:
假设输入矩阵 x 的维度为 [b,s,h],那么无论是前向传播,还是反向传播求梯度,这类操作的计算量都为 C \cdot bsh,这里的 C 是一个常数。
上一小节已经说明了对于两个矩阵相乘的计算量为 2bsh^2,可以看出逐元素运算的计算量要比两个矩阵相乘的计算两少了一个维度 h,一般来说也就差了三到四个数量级,比如 LLAMA-65B 的隐藏层维度为 8192。所以在估计整个网络的计算量时,所有的逐元素运算符的计算量可以忽略掉。
5、注意力矩阵的计算量#
注意力矩阵部分的计算公式如下所示:
对于 Q \cdot K^T,矩阵 Q 和矩阵 K 的维度都是 [b, a, s, h_a],计算时的维度变化为:
计算量为 2 \cdot bash_as = 2bs^2h
除以 \sqrt{d} 和 \text{softmax}() 操作都是逐元素运算,这里直接忽略。直接看乘以矩阵 V 的维度变化:
计算量为 2 \cdot bas^2h_a=2bs^2h
所以注意力矩阵这部分网络结构,一次前向传播计算量为 4bs^2h。反向传播的计算量为前向传播的两倍,那么反向传播的计算量为 8bs^2h。
6、transformer 结构分析#
有了第3、4、5小节的分析之后,汇总来估计一个完整的 transformer 结构的计算量。这里先只分析前向传播的计算量,因为反向传播的计算量只要乘以两倍就可以得到。
transformer 层的公式如下所示:
在估计计算量时,像 softmax、dropout、normalize、残差连接、激活函数等结构都是针对矩阵中的某个元素做的运算。模型中这些结构的运算量相比于矩阵相乘要低几个数量级,所以在估计计算量时忽略这部分结构的计算量。于是,简化后的公式如下所示(下述简化只是为了方便估算计算量):
在后续估计计算量时,上述公式中的 \text{softmax}(\cdot) 和残差连接的计算量是被忽略的。
对于下述三个公式,维度变化都是 [b,s,h] * [h,h] \rightarrow [b,s,h],按照第3小节的分析,这三个公式总计算量为 3 * 2bsh^2=6bsh^2
对于下述公式,在第4小节已经分析过了,其计算量为 4bs^2h
对于下述公式,维度变化为 [b,s,h] * [h,h] \rightarrow [b,s,h],按照第3小节的分析,计算量为 2bsh^2
对于下述公式,先乘矩阵 W_{ff1},再乘矩阵 w_{ff2},维度变化为 [b,s,h] * [h,4h] \rightarrow [b,s,4h]、[b,s,4h] * [4h,h] \rightarrow [b,s,h],按照第3小节的分析,总计算量为 2 * bsh * 4h + 2 * bs * 4h * h=16bsh^2
将上述四部分加起来得到 transformer 结构的总计算量为 6bsh^2+4bs^2h+2bsh^2+16bsh^2=24bsh^2+4bs^2h。
综上,transformer 结构一次前向传播计算量为 24bsh^2+4bs^2h,反向传播的计算量为其两倍,再考虑上梯度检查,最终完成训练所需的总计算量为 4 *(24bsh^2+4bs^2h)。
如果只是为了粗略估计的话,一般会把上述结果公式中的 4bs^2h 部分省略掉,仅保留 4 * 24bsh^2 这部分。我们知道 token 的数量为 bs,并且在文章 参数量分析 中分析得出 transformer 的参数量为 12h^2,所以关于计算量有如下式成立:
这也就是第3小节提到的公式。
关于公式 4 *(24bsh^2+4bs^2h),中的 4bs^2h 是否是远小于 24bsh^2,这个其实并不是一定满足的,只是省略掉该项之后会很容易估算。