分析 self-attention 中 QKV 三个向量各自的作用#
首先 self-attention 的公式如下所示,这个都知道,但是 QKV 这三部分各自的作用一直不清晰,本文是对 QKV 各自具体的作用做一下分析。
1、分析公式#
为了后文的描述,先说明概念并定义符号。
在 attention 中,key 序列与 value 序列必然是同一个序列,而 query 序列却不一定与 key/value 序列是同一个序列。后文在描述时使用 "query序列" 和 "key/value序列" 这两种说法指代这两个序列。将 query 序列的长度定义为 L,将 key/value 序列的长度定义为 L',则 Q、K、V 这三个矩阵的维度为:(L, d), (L', d), (L', d),其中 d 是每个 token 对应的向量的维度。
1.1 计算注意力权重矩阵#
把矩阵 Q 和 K 中的每一行写成向量的形式,为:Q = \begin{bmatrix} \vec{q_1} \\ \vec{q_2} \\ \vdots \\ \vec{q_l} \\ \vdots \\ \vec{q_L} \end{bmatrix} 和 K = \begin{bmatrix} \vec{k_1} \\ \vec{k_2} \\ \vdots \\ \vec{k_{l'}} \\ \vdots \\ \vec{k_{L'}} \end{bmatrix},其中所有的元素 \vec{q_l} 和 \vec{k_{l'}} 都是维度为 d 的向量。将这两个矩阵相乘得到:
接下来就该对上述矩阵做 \text{softmax}(\cdot) 操作了,注意是对上述矩阵的每一行做 \text{softmax}(\cdot),变形过程如下式,该矩阵就是注意力权重矩阵了:
该矩阵的维度是 (L, L'),其中每一行对应的是 query 序列中某个 token 对 key/value 序列中所有 token 的注意力权重。把上述矩阵中的每一行对应的向量用一个符号 \vec{a_l} 来表示,那么上述注意力权重矩阵就变为了 \begin{bmatrix} \vec{a_1} \\ \vec{a_2} \\ \vdots \\ \vec{a_l} \\ \vdots \\ \vec{a_L} \end{bmatrix},其中 \vec{a_l} 是一个有 L' 个元素的向量,该向量是经过softmax之后得到的,即其中的每个值都介于 0~1 之间,且该向量的所有元素之和为1。\vec{a_l} 的含义为 query 序列中第 l 个 token 对 key/value 序列中所有 token 的注意力权重。
1.2 计算 self-attention 最终输出#
对于矩阵 V,和矩阵 K 一样,也是将每一行写成向量的形式,即:V = \begin{bmatrix} \vec{v_1} \\ \vec{v_2} \\ \vdots \\ \vec{v_{l'}} \\ \vdots \\ \vec{v_{L'}} \end{bmatrix},接下来就把上述注意力权重矩阵乘上矩阵 V,得到下式:
先看上述公式(4)中间那个等号右侧的矩阵,该矩阵总共有 L 行,对应 L 个向量,每个向量是 key/value 序列中每个 token 对应的隐向量 \vec{v} 按照注意力权重加权求和得到的。在最开始时就已经定义了 L 表示的是 query 序列的长度,所以在文本生成任务的背景下,公式(4)计算出的矩阵每行的向量的含义为:在 query 序列中,已知每个 t 时刻的 token,预测下一个 t+1 时刻的 token 时的 logits。也就是说,公式(4)中的 \vec{\text{logits}_1} 表示给定了 t=1 时刻的 token,预测下一个 t=2 时刻时计算出的 logits;\vec{\text{logits}_l} 表示给定了 t=l 时刻的 token,预测下一个 t=l+1 时刻时计算出的 logits。
所以可见,常规的 self-attention 的公式会把整个 query 序列中,给定每个 token,预测下一个 token 的所有 logits 都计算出来。对于 clm 任务,这在训练时是没有问题的,因为训练时是希望并行的,一次性把整个序列的 loss 全都计算出来。但是在推理时,我们仅想知道最后一个 token 的下一个 token 是什么,前面的并不想知道,所以除了最后一个 token 对应的 logits 需要计算以外其他的都不需要计算,常规的 self-attention 是有些计算资源的浪费的,不过使用了 KVCache 技术之后这部分浪费基本就被优化掉了。
2、分析代码#
依然还是使用 GPT2 的代码做分析,如果对原始的 GPT2 模型结构和代码不了解,请参考 GPT2 模型结构和实现代码。
与分析公式中相同,key 序列与 value 序列必然是同一个序列,而 query 序列却不一定与 key/value 序列是同一个序列。注释里描述时使用 "query序列" 和 "key/value序列" 这两种说法指代这两个序列。在下述代码的注释中使用 q_seq_len
用来说明 query 序列的长度,使用 kv_seq_len
用来说明 key/value 序列的长度。
对代码的说明都放在注释中了,在 forward()
函数中注释比代码还要长。
class GTP2Attention(nn.Module):
""" 该类为一个完整的 multi-head attention 结构 """
def __init__(self, config):
super().__init__()
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states, attention_mask=None):
# 将 W_q, W_k, W_v 三个矩阵合成一次矩阵乘法运算,乘完之后再分割开来
# query: [bs, q_seq_len, embed_dim]
# key: [bs, kv_seq_len, embed_dim]
# value: [bs, kv_seq_len, embed_dim]
query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)
# ========================================================================
# 经过下面这几行代码的变形之后,query, key, value 的维度如下:
# query: [bs, num_heads, q_seq_len, head_dim]
# key : [bs, num_heads, kv_seq_len, head_dim]
# value: [bs, num_heads, kv_seq_len, head_dim]
#
# 在后面的分析中基本都会忽略掉 batch_size 以及多头部分,仅分析后两个维度
# ========================================================================
q_new_size = query.size()[:-1] + [self.num_heads, self.head_dim]
query = query.view(q_new_size)
kv_new_size = key.size()[:-1] + [self.num_heads, self.head_dim]
key, value = key.view(kv_new_size), value.view(kv_new_size)
query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)
# ========================================================================
# 下面这行代码对应着上文的公式(2),当然这里没有除以根号d;
# 这行代码得到的注意力权重矩阵 att_weights 的维度为: [bs, num_heads, q_seq_len, kv_seq_len]
# ========================================================================
att_weights = torch.matmul(query, key.transpose(-1, -2))
# ========================================================================
# 对刚才得到的注意力矩阵除上一个根号d_k,然后再做softmax;
# 注意这里做 softmax 时是对最后一个维度做的,与上面的公式(3)是一致的;
# 经过 softmax 之后得到的就是最终的注意力权重矩阵了,维度还是: [bs, num_heads, q_seq_len, kv_seq_len]
# 忽略前两个维度之后,(q_seq_len, kv_seq_len) 表示的就是 query 序列中的 token 对 key/value 序列中的每个 token 的注意力权重
# ========================================================================
att_weights = att_weights / torch.full([], value.size()[-1] ** 0.5)
if attention_mask is not None:
att_weights = att_weights + attention_mask
att_weights = nn.functional.softmax(att_weights, dim=-1)
att_weights = self.attn_dropout(att_weights)
# ========================================================================
# 下述矩阵乘法的维度变换为:
# [bs, num_heads, q_seq_len, head_dim] = [bs, num_heads, q_seq_len, kv_seq_len] * [bs, num_heads, kv_seq_len, head_dim]
#
# 忽略前两个维度,只看后两个维度,即: [q_seq_len, head_dim] = [q_seq_len, kv_seq_len] * [kv_seq_len, head_dim]
#
# 这里的矩阵乘法的含义为:将 key/value 序列的 value 矩阵中每个隐层的向量按照 `query 序列中的每个 token 对 key/value 序列中每个 token 的关注权重` 进行加权求和。
#
# 这里的 `query 序列中的每个 token 对 key/value 序列中每个 token 的关注权重` 就是上面解释的注意力权重矩阵 att_weights 的含义。
#
# 最终得到的矩阵 att_output 的维度为: [bs, num_heads, q_seq_len, head_dim],忽略前两个维度,只看 (q_seq_len, head_dim),这个矩阵的其实就是:
#
# 对 query 序列中的每个字符生成下一个字符时的 logits 向量
# ========================================================================
att_output = torch.matmul(att_weights, value)
# 注意:这里不能够直接将 att_output 由 [bs, num_heads, seq_len, head_dim] 变为 [bs, seq_len, embed_dim]
# [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
att_output = att_output.permute(0, 2, 1, 3)
att_output = att_output.view(att_output.size()[:-2] + [self.num_heads * self.head_dim, ])
att_output = self.c_proj(att_output)
att_output = self.resid_dropout(att_output)
return att_output