KVCache#
关于该技术知乎的 Young 的回答 非常清晰易懂,推荐去看一看。本文的中文描述基本都是来自该回答。
1、KVCache是啥?#
大模型推理性能优化的一个常用技术是KVCache,该技术可以在不影响任何计算精度的前提下,通过空间换时间思想,提高推理性能。
2、背景#
生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。
3、原理#
在上面的推理过程中,每 step 内,输入一个 token 序列,经过 Embedding 层将输入 token 序列变为一个三维张量[bs, seq_len, embed_dim],经过一通计算,最后经 logits 层将计算结果映射至词表空间,输出张量维度为[bs, seq_len, vocab_size]。
当前轮输出 token 与输入 tokens 拼接,并作为下一轮的输入 tokens,反复多次。可以看出第 i+1 轮输入数据只比第 i 轮输入数据新增了一个token,其他全部相同。因此第 i+1 轮推理时必然包含了第 i 轮的部分计算。KVCache 的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果。
4、调用层面实现细节#
目前各大模型推理都实现了 KVCache,下面就看如何使用了。增加了 KVCache 之后主要改动:
-
在推理时新增了 past_key_values 参数,该参数就会以追加方式保存每一轮的 (k,v) 值。past_key_values 变量内容为((k,v), (k,v), ..., (k,v)),即有 n_{\text{layers}} 个 (k,v) 组成的一个元组,其中 k 和 v 的维度均为 [bs, num_heads, seq_len, head_dims]。这里可以顺带计算出每轮推理对应的 cache 数据量为 2*\text{bs}*\text{seq_len}*\text{num_heads}*\text{head_dims}*n_{\text{layers}}。以GPT3-175B为例,假设以 float16 来保存 KVCache,senquence长度为100,batch_size=1,则 KVCache占用显存为 (2×1×100×12288×96)×2Byte= 472MB。
-
推理输出的 token 直接作为下一轮的输入,不再拼接,因为上文信息已经在 KVCache 中。
如果使用像 huggingface 的 transformers 库中提供的 model.generate()
函数,那么基本什么都不用做就启用了 KVCache,这里是为了了解其原理,所以不调用该函数,而是使用 model 的 forward() 函数。下面是在推理时不使用 KVCache 和使用 KVCache 的代码对比:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model_name_or_path = "/path/for/gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name_or_path, torchscript=True).eval()
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# 不使用 KVCache 时的推理代码
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
token_eos = torch.tensor([198]) # 终止符,遇到该字符就结束推理
out_token = None
with torch.no_grad():
while out_token != token_eos:
logits, _ = model(in_tokens)
out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
in_tokens = torch.cat((in_tokens, out_token), 0) # 每次都把之前的所有token与推理得到的新token拼接起来作为下次的输入
text = tokenizer.decode(in_tokens)
out_text = tokenizer.decode(in_tokens)
print(out_text)
# 使用 KVCache 时的推理代码
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
token_eos = torch.tensor([198]) # 终止符,遇到该字符就结束推理
out_token = None
kvcache = None
out_text = in_text
with torch.no_grad():
while out_token != token_eos:
logits, kvcache = model(in_tokens, past_key_values=kvcache) # 这里需要传入上次推理缓存的(k,v)
out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
in_tokens = out_token # 每次输出的out_token直接作为下一次的输入,这个不需要再和之前的in_tokens做拼接了
text = tokenizer.decode(in_tokens)
out_text += text
print(out_text)
5、底层实现细节#
KV Cache 配置开启后,推理过程可以分为2个阶段:
-
预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存 key cache 和 value cache,在输出 token 时 Cache 完成填充。
-
使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的 Key、Value 追加写入至Cache。
下面是以GPT2模型为例,展示了如果在原始的GPT2模型上添加 KVCache 功能需要修改哪些代码。
如果对原始的 GPT2 模型结构和代码不了解,请参考 GPT2 模型结构和实现代码,本文中 GPT2 的原始代码就是取自该文。下述代码中注释已经足够清晰。同样的,该代码仅是用于学习理解,其在语法上有不少的bug,也是跑不通的;同时该代码不一定是最优雅的实现,而是尽量选择容易理解的实现方式。
阅读并理解该代码时可以思考以下问题:
- KVCache 减少了 multi-head attention 层中哪部分运算?
- 使用了 KVCache 之后,MLP层(也称FFN层)的运算有没有减少?
- 为什么仅缓存 key 和 value,不缓存 query?
import torch
from torch import nn
class GTP2Attention(nn.Module):
""" 该类为一个完整的 multi-head attention 结构 """
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
- def forward(self, hidden_states, attention_mask=None):
+ def forward(self, hidden_states, attention_mask=None, layer_past=None, use_cache=False):
# 如果是预填充阶段,那么hidden_states的维度是: [bs, seq_len, embed_dim]
# 如果是使用KVCache阶段,那么hidden_states的维度是: [bs, 1, embed_dim]
# 将 W_q, W_k, W_v 三个矩阵合成一次矩阵乘法运算,乘完之后再分割开来
# 如果是预填充阶段,那么query, key, value的维度都是: [bs, seq_len, embed_dim]
# 如果是使用KVCache阶段,那么query, key, value的维度都是: [bs, 1, embed_dim]
query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)
# 如果是预填充阶段,那么query, key, value的维度变化为:
# [bs, seq_len, embed_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, num_heads, seq_len, head_dim]
# 如果是使用KVCache阶段,那么query, key, value的维度变化为:
# [bs, 1, embed_dim] -> [bs, 1, num_heads, head_dim] -> [bs, num_heads, 1, head_dim]
new_size = query.size()[:-1] + [self.num_heads, self.head_dim]
query, key, value = query.view(new_size), key.view(new_size), value.view(new_size)
query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)
+ if layer_past is not None:
+ past_key, past_value = layer_past
+
+ # past_key: [bs, num_heads, past_seq_len, head_dim]
+ # key: [bs, num_heads, 1, head_dim] -> [bs, num_heads, seq_len, head_dim],其中 seq_len = past_seq_len + 1
+ key = torch.cat((past_key, key), dim=-2)
+
+ # past_value: [bs, num_heads, past_seq_len, head_dim]
+ # value: [bs, num_heads, 1, head_dim] -> [bs, num_heads, seq_len, head_dim],其中 seq_len = past_seq_len + 1
+ value = torch.cat((past_value, value), dim=-2)
+ present = None
+ if use_cache:
+ present = (key, value)
# 如果是预填充阶段,那么注意力权重矩阵att_weights的维度为: [bs, num_heads, seq_len, seq_len]
# 如果是使用KVCache阶段,那么注意力权重矩阵att_weights的维度为: [bs, num_heads, 1, seq_len]
att_weights = torch.matmul(query, key.transpose(-1, -2))
# 对注意力矩阵除上一个根号d_k,然后再做softmax
att_weights = att_weights / torch.full([], value.size()[-1] ** 0.5)
if attention_mask is not None:
att_weights = att_weights + attention_mask
att_weights = nn.functional.softmax(att_weights, dim=-1)
att_weights = self.attn_dropout(att_weights)
# 如果是预填充阶段,那么att_output的维度为:
# [bs, num_heads, seq_len, head_dim] = [bs, num_heads, seq_len, seq_len] * [bs, num_heads, seq_len, head_dim]
# 如果是使用KVCache阶段,那么att_output的维度为:
# [bs, num_heads, 1, head_dim] = [bs, num_heads, 1, seq_len] * [bs, num_heads, seq_len, head_dim]
att_output = torch.matmul(att_weights, value)
# 如果是预填充阶段,那么维度变化是:
# [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
# 如果是使用KVCache阶段,那么维度变化是:
# [bs, num_heads, 1, head_dim] -> [bs, 1, num_heads, head_dim] -> [bs, 1, embed_dim]
att_output = att_output.permute(0, 2, 1, 3)
att_output = att_output.view(att_output.size()[:-2] + [self.num_heads * self.head_dim, ])
#
att_output = self.c_proj(att_output)
att_output = self.resid_dropout(att_output)
- return att_output
+ if use_caceh:
+ return (att_output, present)
+ else:
+ return att_output
class GPT2MLP(nn.Module):
""" 该类为一个完整的 FFN 结构 """
def __init__(self, intermediate_size, config):
super().__init__()
self.embed_dim = config.hidden_size
# 一般情况下 intermediate_size = 4 * self.embed_dim,所以这里是一个先升维,后降维的过程
self.c_fc = nn.Linear(self.embed_dim, intermediate_size)
self.c_proj = nn.Linear(intermediate_size, self.embed_dim)
self.act = ... # TODO 这里初始化一个激活函数,比如 ReLU、GELU 等
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states, ):
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class GPT2Block(nn.Module):
""" 该类为一个完整的 transformer 结构 """
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
inner_dim = 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.att = GTP2Attention(config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
- def forward(self, hidden_states, attention_mask=None):
+ def forward(self, hidden_states, attention_mask=None, layer_past=None, use_cache=False):
# 如果是预填充阶段,那么 hidden_states 维度为: [bs, seq_len, embed_dim]
# 如果是使用KVCache阶段,那么 hidden_states 维度为: [bs, 1, embed_dim]
# 相应的,在该函数中下面的变量 residual, attn_outputs, feed_forward_hidden_states 维度都与 hidden_states 是相同的
# ------------------------------------------------------
# multi-head attention 部分
# ------------------------------------------------------
# 保存一下 multi-head attention 层的输入值,等会用于残差连接
residual = hidden_states
# 前置(pred)layer norm
hidden_states = self.ln_1(hidden_states)
- attn_outputs = self.att(hidden_states, attention_mask)
+ if use_cache:
+ attn_outputs, present = self.att(hidden_states, attention_mask, layer_past)
+ else:
+ attn_outputs = self.att(hidden_states, sttention_mask, layer_past)
# 残差连接
hidden_states = attn_outputs + residual
# ------------------------------------------------------
# FFN 部分
# ------------------------------------------------------
# 保存一下 FFN 层的输入值,等会用于残差连接
residual = hidden_states
# 前置(pred)layer norm
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# 残差连接
hidden_states = residual + feed_forward_hidden_states
- return hidden_states
+ if use_cache:
+ return (hidden_states, present)
+ else:
+ return hidden_states
class GPT2Model(nn.Module):
""" 该类是堆叠 embedding 层以及多个 transformer 层 """
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
- def forward(self, input_ids, attention_mask):
+ def forward(self, input_ids, attention_mask, past_key_values=None, use_cache=False):
- # position_ids 的维度为: [bs, seq_len]
- input_shape = input_ids.size()
- position_ids = torch.arange(input_shape[-1]).unsqueeze(0)
+
+ # 如果是预填充阶段,那么 position_ids 的维度为: [bs, seq_len]
+ # 如果是使用KVCache阶段,那么 position_ids 的维度为: [bs, 1],不过要注意 position_ids 里面的值不是0,而是 past_length
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].size(-2)
+ else:
+ past_length = 0
+ input_shape = input_ids.size()
+ position_ids = torch.arange(start=past_length, end=input_shape[-1] + past_length).unsqueeze(0)
# Embedding层,这里只相加了token embedding和position embeddding,忽略了token type embedding
inputs_embed = self.wte(input_ids)
position_embed = self.wpe(position_ids)
# 如果是预填充阶段,那么 inputs_embed 的维度是: [bs, seq_len, embed_dim], position_embed 的维度是 [1, seq_len, embed_dim]
# 如果是使用KVCache阶段,那么 inputs_embed 的维度是: [bs, 1, embed_dim], position_embed 的维度是 [1, 1, embed_dim]
# 无论上述哪个阶段,这里 inputs_embed 和 position_embed 相加时都会在 dim=0 这个维度做广播
hidden_states = inputs_embed + position_embed
hidden_states = self.drop(hidden_states)
# 经过多个transformer层
- for block in self.h:
- hidden_states = block(hidden_states, attention_mask)
+
+ # past_key_values 中存储的是上次推理时保存的KVCache,本次推理时会使用这里面的变量;
+ # presents 中存储的是本次推理时的KVCache,这个会返回用于下次推理;
+ presents = [] if use_cache else None
+ for block, layer_past in zip(self.h, past_key_values):
+ if use_cache:
+ hidden_states, present = block(hidden_states, attention_mask, layer_past)
+ presents += [present, ]
+ else:
+ hidden_states = block(hidden_states, attention_mask, layer_past)
# 由于tranformer中使用的是pre layer norm,所以最后还需要过一层layer norm
hidden_states = self.ln_f(hidden_states)
- return hidden_states
+ if use_cache: # 如果使用KVCache,会把本次推理时的缓存返回,下次推理时使用
+ return (hidden_states, presents)
+ else:
+ return hidden_states
看完代码之后再来看这三个问题:
-
KVCache 减少了 multi-head attention 层中哪部分运算?
- 在预填充阶段阶段是没有任何优化的,和不使用KVCache的计算量相同,下面只分析使用KVCache阶段的阶段,并且假设推理时的 bs 都为 1;
- 在计算 xW_q、xW_k、xW_v 时,原来的 x 为 [1, seq_len, embed_dim],使用 KVCache 之后的 x 为 [1, 1, embed_dim],计算量减小;
- 在计算 \text{softmax}(\frac{Q \cdot K^{\top}}{\sqrt{d_k}}) \cdot V 时,K 和 V 的维度都是和原来一样,但是 Q 的维度变小了,其为 \text{seq_len}=1。
-
使用了 KVCache 之后,MLP层(也称FFN层)的运算有没有减少?
- MLP层的计算量也减小了,虽然在上述代码中
GPT2MLP
里面的代码没有任何修改,但是其forward()
函数的输入维度变了,原来是 [1, seq_len, ebemd_dim],使用 KVCache 之后是 [1, 1, ebemd_dim]。
- MLP层的计算量也减小了,虽然在上述代码中
-
为什么仅缓存 key 和 value,不缓存 query?
- 因为 query 仅需要使用最后一个token,与前面的token无关。