[隐藏左侧目录栏][显示左侧目录栏]

分析 self-attention 中 QKV 三个向量各自的作用#

首先 self-attention 的公式如下所示,这个都知道,但是 QKV 这三部分各自的作用一直不清晰,本文是对 QKV 各自具体的作用做一下分析。

\begin{equation}\text{Attention} = \text{softmax}(\frac{Q K^T}{\sqrt{d}}) \cdot V\end{equation}

1、分析公式#

为了后文的描述,先说明概念并定义符号。

在 attention 中,key 序列与 value 序列必然是同一个序列,而 query 序列却不一定与 key/value 序列是同一个序列。后文在描述时使用 "query序列" 和 "key/value序列" 这两种说法指代这两个序列。将 query 序列的长度定义为 L,将 key/value 序列的长度定义为 L',则 QKV 这三个矩阵的维度为:(L, d), (L', d), (L', d),其中 d 是每个 token 对应的向量的维度。

1.1 计算注意力权重矩阵#

把矩阵 QK 中的每一行写成向量的形式,为:Q = \begin{bmatrix} \vec{q_1} \\ \vec{q_2} \\ \vdots \\ \vec{q_l} \\ \vdots \\ \vec{q_L} \end{bmatrix}K = \begin{bmatrix} \vec{k_1} \\ \vec{k_2} \\ \vdots \\ \vec{k_{l'}} \\ \vdots \\ \vec{k_{L'}} \end{bmatrix},其中所有的元素 \vec{q_l}\vec{k_{l'}} 都是维度为 d 的向量。将这两个矩阵相乘得到:

\begin{equation}\begin{split} \frac{1}{\sqrt{d}} \cdot Q \cdot K^T &= \frac{1}{\sqrt{d}} \cdot \begin{bmatrix} \vec{q_1} \\ \vec{q_2} \\ \vdots \\ \vec{q_l} \\ \vdots \\ \vec{q_L} \end{bmatrix} \cdot [\vec{k_1^T}, \vec{k_2^T}, ..., \vec{k_{l'}^T}, ..., \vec{k_{L'}^T}] \\ &= \frac{1}{\sqrt{d}} \cdot \begin{bmatrix} \vec{q_1}\vec{k_1^T} & \vec{q_1}\vec{k_2^T} & ... & \vec{q_1}\vec{k_{l'}^T} & ... & \vec{q_1}\vec{k_{L'}^T} \\ \vec{q_2}\vec{k_1^T} & \vec{q_2}\vec{k_2^T} & ... & \vec{q_2}\vec{k_{l'}^T} & ... & \vec{q_2}\vec{k_{L'}^T} \\ & & \vdots \\ \vec{q_l}\vec{k_1^T} & \vec{q_l}\vec{k_2^T} & ... & \vec{q_l}\vec{k_{l'}^T} & ... & \vec{q_l}\vec{k_{L'}^T} \\ & & \vdots \\ \vec{q_L}\vec{k_1^T} & \vec{q_L}\vec{k_2^T} & ... & \vec{q_L}\vec{k_{l'}^T} & ... & \vec{q_L}\vec{k_{L'}^T} \end{bmatrix} \end{split}\end{equation}

接下来就该对上述矩阵做 \text{softmax}(\cdot) 操作了,注意是对上述矩阵的每一行做 \text{softmax}(\cdot),变形过程如下式,该矩阵就是注意力权重矩阵了:

\begin{equation}\begin{split} \text{att weight matrix} = \text{softmax}(\frac{1}{\sqrt{d}} \cdot Q \cdot K^T) &= \begin{bmatrix} \text{softmax}(\begin{bmatrix}\vec{q_1}\vec{k_1^T} & \vec{q_1}\vec{k_2^T} & ... & \vec{q_1}\vec{k_{l'}^T} & ... & \vec{q_1}\vec{k_{L'}^T}\end{bmatrix}) \\ \text{softmax}(\begin{bmatrix}\vec{q_2}\vec{k_1^T} & \vec{q_2}\vec{k_2^T} & ... & \vec{q_2}\vec{k_{l'}^T} & ... & \vec{q_2}\vec{k_{L'}^T}\end{bmatrix}) \\ \vdots \\ \text{softmax}(\begin{bmatrix}\vec{q_l}\vec{k_1^T} & \vec{q_l}\vec{k_2^T} & ... & \vec{q_l}\vec{k_{l'}^T} & ... & \vec{q_l}\vec{k_{L'}^T}\end{bmatrix}) \\ \vdots \\ \text{softmax}(\begin{bmatrix}\vec{q_L}\vec{k_1^T} & \vec{q_L}\vec{k_2^T} & ... & \vec{q_L}\vec{k_{l'}^T} & ... & \vec{q_L}\vec{k_{L'}^T}\end{bmatrix}) \\ \end{bmatrix} \\ &= \begin{bmatrix} a_{11} & a_{12} & ... & a_{1l'} & ... & a_{1L'} \\ a_{21} & a_{22} & ... & a_{2l'} & ... & a_{2L'} \\ & & \vdots \\ a_{l1} & a_{l2} & ... & a_{ll'} & ... & a_{lL'} \\ & & \vdots \\ a_{L1} & a_{L2} & ... & a_{Ll'} & ... & a_{LL'} \\ \end{bmatrix} \end{split}\end{equation}

该矩阵的维度是 (L, L'),其中每一行对应的是 query 序列中某个 token 对 key/value 序列中所有 token 的注意力权重。把上述矩阵中的每一行对应的向量用一个符号 \vec{a_l} 来表示,那么上述注意力权重矩阵就变为了 \begin{bmatrix} \vec{a_1} \\ \vec{a_2} \\ \vdots \\ \vec{a_l} \\ \vdots \\ \vec{a_L} \end{bmatrix},其中 \vec{a_l} 是一个有 L' 个元素的向量,该向量是经过softmax之后得到的,即其中的每个值都介于 0~1 之间,且该向量的所有元素之和为1。\vec{a_l} 的含义为 query 序列中第 l 个 token 对 key/value 序列中所有 token 的注意力权重。

1.2 计算 self-attention 最终输出#

对于矩阵 V,和矩阵 K 一样,也是将每一行写成向量的形式,即:V = \begin{bmatrix} \vec{v_1} \\ \vec{v_2} \\ \vdots \\ \vec{v_{l'}} \\ \vdots \\ \vec{v_{L'}} \end{bmatrix},接下来就把上述注意力权重矩阵乘上矩阵 V,得到下式:

\begin{equation}\begin{split} \text{att weight matrix} \cdot V &= \begin{bmatrix} a_{11} & a_{12} & ... & a_{1l'} & ... & a_{1L'} \\ a_{21} & a_{22} & ... & a_{2l'} & ... & a_{2L'} \\ & & \vdots \\ a_{l1} & a_{l2} & ... & a_{ll'} & ... & a_{lL'} \\ & & \vdots \\ a_{L1} & a_{L2} & ... & a_{Ll'} & ... & a_{LL'} \\ \end{bmatrix} \cdot \begin{bmatrix} \vec{v_1} \\ \vec{v_2} \\ \vdots \\ \vec{v_{l'}} \\ \vdots \\ \vec{v_{L'}} \end{bmatrix} \\ &= \begin{bmatrix} a_{11}\vec{v_1} + a_{12}\vec{v_2} + ... + a_{1l'}\vec{v_{l'}} + ... + a_{1L'}\vec{v_{L'}} \\ a_{21}\vec{v_1} + a_{22}\vec{v_2} + ... + a_{2l'}\vec{v_{l'}} + ... + a_{2L'}\vec{v_{L'}} \\ \vdots \\ a_{l1}\vec{v_1} + a_{l2}\vec{v_2} + ... + a_{ll'}\vec{v_{l'}} + ... + a_{lL'}\vec{v_{L'}} \\ \vdots \\ a_{L1}\vec{v_1} + a_{L2}\vec{v_2} + ... + a_{Ll'}\vec{v_{l'}} + ... + a_{LL'}\vec{v_{L'}} \end{bmatrix} \\ &= \begin{bmatrix} \vec{\text{logits}_1} \\ \vec{\text{logits}_2} \\ \vdots \\ \vec{\text{logits}_l} \\ \vdots \\ \vec{\text{logits}_L} \end{bmatrix} \end{split}\end{equation}

先看上述公式(4)中间那个等号右侧的矩阵,该矩阵总共有 L 行,对应 L 个向量,每个向量是 key/value 序列中每个 token 对应的隐向量 \vec{v} 按照注意力权重加权求和得到的。在最开始时就已经定义了 L 表示的是 query 序列的长度,所以在文本生成任务的背景下,公式(4)计算出的矩阵每行的向量的含义为:在 query 序列中,已知每个 t 时刻的 token,预测下一个 t+1 时刻的 token 时的 logits。也就是说,公式(4)中的 \vec{\text{logits}_1} 表示给定了 t=1 时刻的 token,预测下一个 t=2 时刻时计算出的 logits;\vec{\text{logits}_l} 表示给定了 t=l 时刻的 token,预测下一个 t=l+1 时刻时计算出的 logits。

所以可见,常规的 self-attention 的公式会把整个 query 序列中,给定每个 token,预测下一个 token 的所有 logits 都计算出来。对于 clm 任务,这在训练时是没有问题的,因为训练时是希望并行的,一次性把整个序列的 loss 全都计算出来。但是在推理时,我们仅想知道最后一个 token 的下一个 token 是什么,前面的并不想知道,所以除了最后一个 token 对应的 logits 需要计算以外其他的都不需要计算,常规的 self-attention 是有些计算资源的浪费的,不过使用了 KVCache 技术之后这部分浪费基本就被优化掉了。

2、分析代码#

依然还是使用 GPT2 的代码做分析,如果对原始的 GPT2 模型结构和代码不了解,请参考 GPT2 模型结构和实现代码

与分析公式中相同,key 序列与 value 序列必然是同一个序列,而 query 序列却不一定与 key/value 序列是同一个序列。注释里描述时使用 "query序列" 和 "key/value序列" 这两种说法指代这两个序列。在下述代码的注释中使用 q_seq_len 用来说明 query 序列的长度,使用 kv_seq_len 用来说明 key/value 序列的长度。

对代码的说明都放在注释中了,在 forward() 函数中注释比代码还要长。

class GTP2Attention(nn.Module):
    """ 该类为一个完整的 multi-head attention 结构 """

    def __init__(self, config):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states, attention_mask=None):
        # 将 W_q, W_k, W_v 三个矩阵合成一次矩阵乘法运算,乘完之后再分割开来
        # query: [bs, q_seq_len, embed_dim]
        # key:   [bs, kv_seq_len, embed_dim]
        # value: [bs, kv_seq_len, embed_dim]
        query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)

        # ========================================================================
        # 经过下面这几行代码的变形之后,query, key, value 的维度如下:
        #   query: [bs, num_heads, q_seq_len, head_dim]
        #   key  : [bs, num_heads, kv_seq_len, head_dim]
        #   value: [bs, num_heads, kv_seq_len, head_dim]
        #
        # 在后面的分析中基本都会忽略掉 batch_size 以及多头部分,仅分析后两个维度
        # ========================================================================
        q_new_size = query.size()[:-1] + [self.num_heads, self.head_dim]
        query = query.view(q_new_size)
        kv_new_size = key.size()[:-1] + [self.num_heads, self.head_dim]
        key, value = key.view(kv_new_size), value.view(kv_new_size)
        query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)

        # ========================================================================
        # 下面这行代码对应着上文的公式(2),当然这里没有除以根号d;
        # 这行代码得到的注意力权重矩阵 att_weights 的维度为: [bs, num_heads, q_seq_len, kv_seq_len]
        # ========================================================================
        att_weights = torch.matmul(query, key.transpose(-1, -2))

        # ========================================================================
        # 对刚才得到的注意力矩阵除上一个根号d_k,然后再做softmax;
        # 注意这里做 softmax 时是对最后一个维度做的,与上面的公式(3)是一致的;
        # 经过 softmax 之后得到的就是最终的注意力权重矩阵了,维度还是: [bs, num_heads, q_seq_len, kv_seq_len]
        # 忽略前两个维度之后,(q_seq_len, kv_seq_len) 表示的就是 query 序列中的 token 对 key/value 序列中的每个 token 的注意力权重
        # ========================================================================
        att_weights = att_weights / torch.full([], value.size()[-1] ** 0.5)
        if attention_mask is not None:
            att_weights = att_weights + attention_mask
        att_weights = nn.functional.softmax(att_weights, dim=-1)
        att_weights = self.attn_dropout(att_weights)

        # ========================================================================
        # 下述矩阵乘法的维度变换为: 
        #   [bs, num_heads, q_seq_len, head_dim] = [bs, num_heads, q_seq_len, kv_seq_len] * [bs, num_heads, kv_seq_len, head_dim]
        # 
        # 忽略前两个维度,只看后两个维度,即: [q_seq_len, head_dim] = [q_seq_len, kv_seq_len] * [kv_seq_len, head_dim]
        # 
        # 这里的矩阵乘法的含义为:将 key/value 序列的 value 矩阵中每个隐层的向量按照 `query 序列中的每个 token 对 key/value 序列中每个 token 的关注权重` 进行加权求和。
        #
        # 这里的  `query 序列中的每个 token 对 key/value 序列中每个 token 的关注权重` 就是上面解释的注意力权重矩阵 att_weights 的含义。
        #
        # 最终得到的矩阵 att_output 的维度为: [bs, num_heads, q_seq_len, head_dim],忽略前两个维度,只看 (q_seq_len, head_dim),这个矩阵的其实就是:
        # 
        #    对 query 序列中的每个字符生成下一个字符时的 logits 向量
        # ========================================================================
        att_output = torch.matmul(att_weights, value)

        # 注意:这里不能够直接将 att_output 由 [bs, num_heads, seq_len, head_dim] 变为 [bs, seq_len, embed_dim]
        # [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
        att_output = att_output.permute(0, 2, 1, 3)
        att_output = att_output.view(att_output.size()[:-2] + [self.num_heads * self.head_dim, ])

        att_output = self.c_proj(att_output)
        att_output = self.resid_dropout(att_output)

        return att_output

Reference#