静态显存分析#
首先要明确训练时使用的显存由哪几部分构成,这里把训练时使用的显存分为如下几部分:
-
Pytorch 上下文:在训练时显卡中的第一份显存肯定要先分配给初始化环境的框架,比如 Pytorch,这部分显存的大小与显卡的型号、Pytorch 的版本号都有关系,想要确定的话直接测试一下即可。
-
静态显存:模型参数消耗的显存,梯度消耗的显存,优化器消耗的显存;
-
中间激活状态消耗的显存;
-
临时缓冲区占用的显存,以及其他零散的显存;
在上述四部分显存中,第一部分显存可以直接进行测试。第二部分的显存是本文的分析重点。第三部分的显存在下一篇文章中间激活值显存分析中进行分析。第四部分显存无法直接分析,同时它相比于前面部分所占用的显存要小很多。
本文主要分析模型参数占用的显存大小,梯度占用的显存大小,以及优化器占用的显存大小。
1、单精度(fp32)训练#
假设模型的参数量为 \Phi。
模型的参数消耗的显存:所有的模型参数需要存储到显存中,使用 fp32 存储,则需要消耗的显存为 4\Phi;
梯度消耗的显存:梯度和模型参数的量是完全相同的,使用 fp32 存储,则需要消耗的显存为 4\Phi;
优化器消耗的显存:以 Adam 为例,使用单精度(fp32)进行训练,Adam 中会存储 averaged momentum 和 variance 两部分,都和模型的参数量是相同的,即 \Phi。所以优化器消耗的显存为 2 * 4\Phi。
这样,模型的参数消耗的显存、梯度消耗的显存、优化器消耗的显存分别为:4\Phi、4\Phi、4\Phi * 2,总计为 16\Phi。以 1.5B 的 GPT-2 为例,总大小为 1.5B * (4 + 4 + 4 * 2) = 24G
2、混合精度训练#
假设模型的参数量为 \Phi。
模型的参数消耗的显存:所有的模型参数需要存储到显存中,使用 fp16 存储,则需要的显存为 2\Phi;
梯度消耗的显存:梯度和模型参数的量是完全相同的,使用 fp16 存储,则需要的显存为 2\Phi;
优化器消耗的显存:以 Adam 为例,使用混合精度进行训练时,Adam 中会使用 fp32 存储一份模型参数的备份,并且使用 fp32 存储 averaged momentum 和 variance 两部分。所以优化器需要消耗的显存为 3 * 4\Phi。
这样,模型的参数消耗的显存、梯度消耗的显存、优化器消耗的显存分别为:2\Phi、2\Phi、4\Phi * 3,总计为 16\Phi。以 1.5B 的 GPT-2 为例,总大小为 1.5B * (2 + 2 + 4 * 3) = 24G
也就是说,对于参数、梯度、优化器三者消耗的显存之和,在使用单精度和混合精度时是相同的。
3、实例分析之bert-base#
bert-base 的参数量为 110M。假设使用 Adam 优化器和混合精度训练,按照上述结论分析一下 bert-base 模型需要的显存。
模型参数 110M * 2 = 220M
梯度 110M * 2 = 220M
优化器 110M * (4 + 4 + 4) = 1.32G
总计需要 220M + 220M + 1.32G = 1.76G
4、实例分析之LLAMA-65B#
LLAMA-65B 的参数量为 65B。假设使用 Adam 优化器和混合精度训练,按照上述结论分析一下 LLAMA-65B 模型需要的显存。
模型参数 65B * 2 = 130G
梯度 65B * 2 = 130G
优化器 65B * (4 + 4 + 4) = 780G
总计需要 130 + 130 + 780 = 1040G
总结#
本文分析了在单精度和混合精度下,模型参数、梯度、优化器状态这三部分静态显存如何计算。并以 bert-base 和 llama-65B 为例估算其所需的静态显存的大小。