[隐藏左侧目录栏][显示左侧目录栏]

KVCache#

关于该技术知乎的 Young 的回答 非常清晰易懂,推荐去看一看。本文的中文描述基本都是来自该回答。

1、KVCache是啥?#

大模型推理性能优化的一个常用技术是KVCache,该技术可以在不影响任何计算精度的前提下,通过空间换时间思想,提高推理性能。

2、背景#

生成式generative模型的推理过程很有特点,我们给一个输入文本,模型会输出一个回答(长度为N),其实该过程中执行了N次推理过程。即GPT类模型一次推理只输出一个token,输出token会与输入tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

3、原理#

在上面的推理过程中,每 step 内,输入一个 token 序列,经过 Embedding 层将输入 token 序列变为一个三维张量[bs, seq_len, embed_dim],经过一通计算,最后经 logits 层将计算结果映射至词表空间,输出张量维度为[bs, seq_len, vocab_size]。

当前轮输出 token 与输入 tokens 拼接,并作为下一轮的输入 tokens,反复多次。可以看出第 i+1 轮输入数据只比第 i 轮输入数据新增了一个token,其他全部相同。因此第 i+1 轮推理时必然包含了第 i 轮的部分计算。KVCache 的出发点就在这里,缓存当前轮可重复利用的计算结果,下一轮计算时直接读取缓存结果。

4、调用层面实现细节#

目前各大模型推理都实现了 KVCache,下面就看如何使用了。增加了 KVCache 之后主要改动:

  • 在推理时新增了 past_key_values 参数,该参数就会以追加方式保存每一轮的 (k,v) 值。past_key_values 变量内容为((k,v), (k,v), ..., (k,v)),即有 n_{\text{layers}} 个 (k,v) 组成的一个元组,其中 k 和 v 的维度均为 [bs, num_heads, seq_len, head_dims]。这里可以顺带计算出每轮推理对应的 cache 数据量为 2*\text{bs}*\text{seq_len}*\text{num_heads}*\text{head_dims}*n_{\text{layers}}。以GPT3-175B为例,假设以 float16 来保存 KVCache,senquence长度为100,batch_size=1,则 KVCache占用显存为 (2×1×100×12288×96)×2Byte= 472MB。

  • 推理输出的 token 直接作为下一轮的输入,不再拼接,因为上文信息已经在 KVCache 中。

如果使用像 huggingface 的 transformers 库中提供的 model.generate() 函数,那么基本什么都不用做就启用了 KVCache,这里是为了了解其原理,所以不调用该函数,而是使用 model 的 forward() 函数。下面是在推理时不使用 KVCache 和使用 KVCache 的代码对比:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model_name_or_path = "/path/for/gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name_or_path, torchscript=True).eval()
tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)

# 不使用 KVCache 时的推理代码
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
token_eos = torch.tensor([198]) # 终止符,遇到该字符就结束推理
out_token = None
with torch.no_grad():
    while out_token != token_eos:
        logits, _ = model(in_tokens)
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = torch.cat((in_tokens, out_token), 0)  # 每次都把之前的所有token与推理得到的新token拼接起来作为下次的输入
        text = tokenizer.decode(in_tokens)
out_text = tokenizer.decode(in_tokens)
print(out_text)

# 使用 KVCache 时的推理代码
in_text = "Lionel Messi is a"
in_tokens = torch.tensor(tokenizer.encode(in_text))
token_eos = torch.tensor([198]) # 终止符,遇到该字符就结束推理
out_token = None
kvcache = None
out_text = in_text
with torch.no_grad():
    while out_token != token_eos:
        logits, kvcache = model(in_tokens, past_key_values=kvcache)  # 这里需要传入上次推理缓存的(k,v)
        out_token = torch.argmax(logits[-1, :], dim=0, keepdim=True)
        in_tokens = out_token  # 每次输出的out_token直接作为下一次的输入,这个不需要再和之前的in_tokens做拼接了
        text = tokenizer.decode(in_tokens)
        out_text += text
print(out_text)

5、底层实现细节#

KV Cache 配置开启后,推理过程可以分为2个阶段:

  • 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存 key cache 和 value cache,在输出 token 时 Cache 完成填充。

  • 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的 Key、Value 追加写入至Cache。

下面是以GPT2模型为例,展示了如果在原始的GPT2模型上添加 KVCache 功能需要修改哪些代码。

如果对原始的 GPT2 模型结构和代码不了解,请参考 GPT2 模型结构和实现代码,本文中 GPT2 的原始代码就是取自该文。下述代码中注释已经足够清晰。同样的,该代码仅是用于学习理解,其在语法上有不少的bug,也是跑不通的;同时该代码不一定是最优雅的实现,而是尽量选择容易理解的实现方式。

阅读并理解该代码时可以思考以下问题:

  • KVCache 减少了 multi-head attention 层中哪部分运算?
  • 使用了 KVCache 之后,MLP层(也称FFN层)的运算有没有减少?
  • 为什么仅缓存 key 和 value,不缓存 query?
  import torch
  from torch import nn

  class GTP2Attention(nn.Module):
      """ 该类为一个完整的 multi-head attention 结构 """

      def __init__(self, config):
          super().__init__()

          self.embed_dim = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.embed_dim // self.num_heads
          self.split_size = self.embed_dim

          self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
          self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

          self.attn_dropout = nn.Dropout(config.attn_pdrop)
          self.resid_dropout = nn.Dropout(config.resid_pdrop)

-     def forward(self, hidden_states, attention_mask=None):
+     def forward(self, hidden_states, attention_mask=None, layer_past=None, use_cache=False):
          # 如果是预填充阶段,那么hidden_states的维度是: [bs, seq_len, embed_dim]
          # 如果是使用KVCache阶段,那么hidden_states的维度是: [bs, 1, embed_dim]

          # 将 W_q, W_k, W_v 三个矩阵合成一次矩阵乘法运算,乘完之后再分割开来
          # 如果是预填充阶段,那么query, key, value的维度都是: [bs, seq_len, embed_dim]
          # 如果是使用KVCache阶段,那么query, key, value的维度都是: [bs, 1, embed_dim]
          query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)

          # 如果是预填充阶段,那么query, key, value的维度变化为:
          #       [bs, seq_len, embed_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, num_heads, seq_len, head_dim]
          # 如果是使用KVCache阶段,那么query, key, value的维度变化为:
          #       [bs, 1, embed_dim] -> [bs, 1, num_heads, head_dim] -> [bs, num_heads, 1, head_dim] 
          new_size = query.size()[:-1] + [self.num_heads, self.head_dim]
          query, key, value = query.view(new_size), key.view(new_size), value.view(new_size)
          query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)

+         if layer_past is not None:
+             past_key, past_value = layer_past
+ 
+             # past_key: [bs, num_heads, past_seq_len, head_dim]
+             # key: [bs, num_heads, 1, head_dim] -> [bs, num_heads, seq_len, head_dim],其中 seq_len = past_seq_len + 1
+             key = torch.cat((past_key, key), dim=-2)
+ 
+             # past_value: [bs, num_heads, past_seq_len, head_dim]
+             # value: [bs, num_heads, 1, head_dim] -> [bs, num_heads, seq_len, head_dim],其中 seq_len = past_seq_len + 1
+             value = torch.cat((past_value, value), dim=-2)

+         present = None
+         if use_cache:
+             present = (key, value)

          # 如果是预填充阶段,那么注意力权重矩阵att_weights的维度为: [bs, num_heads, seq_len, seq_len]
          # 如果是使用KVCache阶段,那么注意力权重矩阵att_weights的维度为: [bs, num_heads, 1, seq_len]
          att_weights = torch.matmul(query, key.transpose(-1, -2))

          # 对注意力矩阵除上一个根号d_k,然后再做softmax
          att_weights = att_weights / torch.full([], value.size()[-1] ** 0.5)
          if attention_mask is not None:
              att_weights = att_weights + attention_mask
          att_weights = nn.functional.softmax(att_weights, dim=-1)
          att_weights = self.attn_dropout(att_weights)

          # 如果是预填充阶段,那么att_output的维度为: 
          #      [bs, num_heads, seq_len, head_dim] = [bs, num_heads, seq_len, seq_len] * [bs, num_heads, seq_len, head_dim]
          # 如果是使用KVCache阶段,那么att_output的维度为: 
          #      [bs, num_heads, 1, head_dim] = [bs, num_heads, 1, seq_len] * [bs, num_heads, seq_len, head_dim]
          att_output = torch.matmul(att_weights, value)

          # 如果是预填充阶段,那么维度变化是: 
          #      [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
          # 如果是使用KVCache阶段,那么维度变化是: 
          #      [bs, num_heads, 1, head_dim] -> [bs, 1, num_heads, head_dim] -> [bs, 1, embed_dim]
          att_output = att_output.permute(0, 2, 1, 3)
          att_output = att_output.view(att_output.size()[:-2] + [self.num_heads * self.head_dim, ])

          # 
          att_output = self.c_proj(att_output)
          att_output = self.resid_dropout(att_output)

-         return att_output
+         if use_caceh:
+             return (att_output, present)
+         else:
+             return att_output


  class GPT2MLP(nn.Module):
      """ 该类为一个完整的 FFN 结构 """

      def __init__(self, intermediate_size, config):
          super().__init__()

          self.embed_dim = config.hidden_size

          # 一般情况下 intermediate_size = 4 * self.embed_dim,所以这里是一个先升维,后降维的过程
          self.c_fc = nn.Linear(self.embed_dim, intermediate_size)
          self.c_proj = nn.Linear(intermediate_size, self.embed_dim)
          self.act = ... # TODO 这里初始化一个激活函数,比如 ReLU、GELU 等
          self.dropout = nn.Dropout(config.resid_pdrop)

      def forward(self, hidden_states, ):
          hidden_states = self.c_fc(hidden_states)
          hidden_states = self.act(hidden_states)
          hidden_states = self.c_proj(hidden_states)
          hidden_states = self.dropout(hidden_states)
          return hidden_states


  class GPT2Block(nn.Module):
      """ 该类为一个完整的 transformer 结构 """

      def __init__(self, config):
          super().__init__()

          hidden_size = config.hidden_size
          inner_dim = 4 * hidden_size

          self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
          self.att = GTP2Attention(config)
          self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
          self.mlp = GPT2MLP(inner_dim, config)

-     def forward(self, hidden_states, attention_mask=None):
+     def forward(self, hidden_states, attention_mask=None, layer_past=None, use_cache=False):
          # 如果是预填充阶段,那么 hidden_states 维度为: [bs, seq_len, embed_dim]
          # 如果是使用KVCache阶段,那么 hidden_states 维度为: [bs, 1, embed_dim]
          # 相应的,在该函数中下面的变量 residual, attn_outputs, feed_forward_hidden_states 维度都与 hidden_states 是相同的

          # ------------------------------------------------------
          # multi-head attention 部分
          # ------------------------------------------------------
          # 保存一下 multi-head attention 层的输入值,等会用于残差连接
          residual = hidden_states
          # 前置(pred)layer norm
          hidden_states = self.ln_1(hidden_states)
-         attn_outputs = self.att(hidden_states, attention_mask)
+         if use_cache:
+             attn_outputs, present = self.att(hidden_states, attention_mask, layer_past)
+         else:
+             attn_outputs = self.att(hidden_states, sttention_mask, layer_past)
          # 残差连接
          hidden_states = attn_outputs + residual

          # ------------------------------------------------------
          # FFN 部分
          # ------------------------------------------------------
          # 保存一下 FFN 层的输入值,等会用于残差连接
          residual = hidden_states
          # 前置(pred)layer norm
          hidden_states = self.ln_2(hidden_states)
          feed_forward_hidden_states = self.mlp(hidden_states)
          # 残差连接
          hidden_states = residual + feed_forward_hidden_states

-         return hidden_states
+         if use_cache:
+             return (hidden_states, present)
+         else:
+             return hidden_states


  class GPT2Model(nn.Module):
      """ 该类是堆叠 embedding 层以及多个 transformer 层 """

      def __init__(self, config):
          super().__init__()

          self.embed_dim = config.hidden_size

          self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
          self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
          self.drop = nn.Dropout(config.embd_pdrop)

          self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
          self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

-     def forward(self, input_ids, attention_mask):
+     def forward(self, input_ids, attention_mask, past_key_values=None, use_cache=False):

-         # position_ids 的维度为: [bs, seq_len]
-         input_shape = input_ids.size()
-         position_ids = torch.arange(input_shape[-1]).unsqueeze(0)
+
+         # 如果是预填充阶段,那么 position_ids 的维度为: [bs, seq_len]
+         # 如果是使用KVCache阶段,那么 position_ids 的维度为: [bs, 1],不过要注意 position_ids 里面的值不是0,而是 past_length
+         if past_key_values is not None:
+             past_length = past_key_values[0][0].size(-2)
+         else:
+             past_length = 0
+         input_shape = input_ids.size()
+         position_ids = torch.arange(start=past_length, end=input_shape[-1] + past_length).unsqueeze(0)

          # Embedding层,这里只相加了token embedding和position embeddding,忽略了token type embedding
          inputs_embed = self.wte(input_ids)
          position_embed = self.wpe(position_ids)
          # 如果是预填充阶段,那么 inputs_embed 的维度是: [bs, seq_len, embed_dim], position_embed 的维度是 [1, seq_len, embed_dim]
          # 如果是使用KVCache阶段,那么 inputs_embed 的维度是: [bs, 1, embed_dim], position_embed 的维度是 [1, 1, embed_dim]
          # 无论上述哪个阶段,这里 inputs_embed 和 position_embed 相加时都会在 dim=0 这个维度做广播
          hidden_states = inputs_embed + position_embed
          hidden_states = self.drop(hidden_states)

          # 经过多个transformer层
-         for block in self.h:
-             hidden_states = block(hidden_states, attention_mask)
+
+         # past_key_values 中存储的是上次推理时保存的KVCache,本次推理时会使用这里面的变量;
+         # presents 中存储的是本次推理时的KVCache,这个会返回用于下次推理;
+         presents = [] if use_cache else None
+         for block, layer_past in zip(self.h, past_key_values):
+             if use_cache:
+                 hidden_states, present = block(hidden_states, attention_mask, layer_past)
+                 presents += [present, ]
+             else:
+                 hidden_states = block(hidden_states, attention_mask, layer_past)

          # 由于tranformer中使用的是pre layer norm,所以最后还需要过一层layer norm
          hidden_states = self.ln_f(hidden_states)

-         return hidden_states
+         if use_cache:  # 如果使用KVCache,会把本次推理时的缓存返回,下次推理时使用
+             return (hidden_states, presents)
+         else:
+             return hidden_states

看完代码之后再来看这三个问题:

  • KVCache 减少了 multi-head attention 层中哪部分运算?

    • 在预填充阶段阶段是没有任何优化的,和不使用KVCache的计算量相同,下面只分析使用KVCache阶段的阶段,并且假设推理时的 bs 都为 1;
    • 在计算 xW_qxW_kxW_v 时,原来的 x 为 [1, seq_len, embed_dim],使用 KVCache 之后的 x 为 [1, 1, embed_dim],计算量减小;
    • 在计算 \text{softmax}(\frac{Q \cdot K^{\top}}{\sqrt{d_k}}) \cdot V 时,KV 的维度都是和原来一样,但是 Q 的维度变小了,其为 \text{seq_len}=1
  • 使用了 KVCache 之后,MLP层(也称FFN层)的运算有没有减少?

    • MLP层的计算量也减小了,虽然在上述代码中 GPT2MLP 里面的代码没有任何修改,但是其 forward() 函数的输入维度变了,原来是 [1, seq_len, ebemd_dim],使用 KVCache 之后是 [1, 1, ebemd_dim]。
  • 为什么仅缓存 key 和 value,不缓存 query?

    • 因为 query 仅需要使用最后一个token,与前面的token无关。

Reference#