
GPT2 Model Structure and Implementation Code#

The GPT2 architecture differs from GPT in a few small ways, and later models such as GPT3 were built on top of GPT2, so it is well worth understanding the details of the GPT2 model structure.

The huggingface transformers project contains a complete GPT2 implementation, but it is a large engineering project with code for all sorts of features, which makes it hard to read. The code below is the result of extracting just the model-structure part of the GPT2 implementation from transformers. Note that it is intended for learning and understanding only: it simplifies away parts of the original implementation (for example the cached key/value mechanism, the causal-mask buffer, and the Conv1D layers), and it is not necessarily the most elegant implementation; ease of understanding is preferred wherever possible.

模型中涉及的几个公式如下所示:

\begin{equation}\text{Attention}(x) = \text{softmax}\left(\frac{(xW_q)(xW_k)^\top}{\sqrt{d_k}}\right) \cdot xW_v\end{equation}
\begin{equation}\text{FFN}(x) = \text{act}(xW_1)W_2\end{equation}
\begin{equation}\text{transformer}(x) = y + \text{FFN}(\text{LN}(y)), \quad \text{where}\ y = x + \text{Attention}(\text{LN}(x))\end{equation}

In the formulas above, \text{act}(\cdot) denotes the activation function and \text{LN}(\cdot) denotes layer normalization; the additions correspond to the residual connections.
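
As a quick sanity check of the attention formula, the short snippet below is a minimal sketch: a single head with toy sizes chosen only for illustration, independent of the classes defined later. It computes softmax(Q K^T / sqrt(d_k)) V directly, where Q = xW_q, K = xW_k, V = xW_v:

import torch

bs, seq_len, embed_dim, d_k = 2, 5, 16, 16           # toy sizes, for illustration only
x = torch.randn(bs, seq_len, embed_dim)
W_q, W_k, W_v = (torch.randn(embed_dim, d_k) for _ in range(3))

q, k, v = x @ W_q, x @ W_k, x @ W_v                   # each: [bs, seq_len, d_k]
scores = q @ k.transpose(-1, -2) / (d_k ** 0.5)       # [bs, seq_len, seq_len]
out = torch.softmax(scores, dim=-1) @ v               # [bs, seq_len, d_k]
print(out.shape)                                      # torch.Size([2, 5, 16])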

The code below contains four classes:

  • GPT2Attention: implements the multi-head attention part;

  • GPT2MLP: implements the FFN (feed-forward network) part;

  • GPT2Block: stacks multi-head attention and FFN, i.e., one transformer layer;

  • GPT2Model: stacks the embedding layer and multiple transformer layers, i.e., a complete decoder-only model structure;
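
All four classes receive a config object; in transformers this is a GPT2Config. As a self-contained stand-in for the code below, a minimal hypothetical SimpleGPT2Config could look like the following sketch (the attribute names mirror exactly what the classes access, and the default values follow GPT2-small):

from dataclasses import dataclass


@dataclass
class SimpleGPT2Config:
    """ Hypothetical stand-in for transformers' GPT2Config (defaults follow GPT2-small) """
    vocab_size: int = 50257              # size of the BPE vocabulary
    max_position_embeddings: int = 1024  # maximum sequence length
    hidden_size: int = 768               # embed_dim
    num_hidden_layers: int = 12          # number of transformer layers
    num_attention_heads: int = 12        # number of attention heads
    attn_pdrop: float = 0.1              # dropout on the attention weights
    resid_pdrop: float = 0.1             # dropout on the residual branches
    embd_pdrop: float = 0.1              # dropout after the embedding layer
    layer_norm_epsilon: float = 1e-5     # eps for nn.LayerNorm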

The main parts of the code are commented, so it should be fairly easy to follow:

import torch
from torch import nn


class GPT2Attention(nn.Module):
    """ A complete multi-head attention block """

    def __init__(self, config):
        super().__init__()

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim

        # the original transformers implementation uses its Conv1D module here;
        # nn.Linear is used instead for readability (equivalent up to a transpose of the weights)
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: [bs, seq_len, embed_dim]

        # Fuse W_q, W_k, W_v into a single matrix multiplication, then split the result
        # query: [bs, seq_len, embed_dim]
        # key:   [bs, seq_len, embed_dim]
        # value: [bs, seq_len, embed_dim]
        query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)

        # [bs, seq_len, embed_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, num_heads, seq_len, head_dim]
        new_size = query.size()[:-1] + (self.num_heads, self.head_dim)
        query, key, value = query.view(new_size), key.view(new_size), value.view(new_size)
        query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)

        # attention weight matrix att_weights: [bs, num_heads, seq_len, seq_len]
        att_weights = torch.matmul(query, key.transpose(-1, -2))

        # scale the attention scores by sqrt(d_k), then apply softmax
        att_weights = att_weights / (value.size(-1) ** 0.5)
        if attention_mask is not None:
            # attention_mask is an additive mask: 0 for visible positions and a large
            # negative value for masked ones; in this simplified implementation it is
            # also expected to carry the causal mask
            att_weights = att_weights + attention_mask
        att_weights = nn.functional.softmax(att_weights, dim=-1)
        att_weights = self.attn_dropout(att_weights)

        # att_output: [bs, num_heads, seq_len, head_dim] = [bs, num_heads, seq_len, seq_len] * [bs, num_heads, seq_len, head_dim]
        att_output = torch.matmul(att_weights, value)

        # Note: att_output cannot be reshaped from [bs, num_heads, seq_len, head_dim] to
        # [bs, seq_len, embed_dim] directly; the head and sequence dims must be swapped back
        # first (and the tensor made contiguous before view)
        # [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
        att_output = att_output.permute(0, 2, 1, 3).contiguous()
        att_output = att_output.view(att_output.size()[:-2] + (self.num_heads * self.head_dim,))

        # output projection back to embed_dim, followed by dropout
        att_output = self.c_proj(att_output)
        att_output = self.resid_dropout(att_output)

        return att_output


class GPT2MLP(nn.Module):
    """ 该类为一个完整的 FFN 结构 """

    def __init__(self, intermediate_size, config):
        super().__init__()

        self.embed_dim = config.hidden_size

        # Typically intermediate_size = 4 * self.embed_dim, so this is an up-projection followed by a down-projection
        self.c_fc = nn.Linear(self.embed_dim, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, self.embed_dim)
        self.act = nn.GELU()  # GPT2 uses a GELU activation (transformers uses the "gelu_new" variant)
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPT2Block(nn.Module):
    """ 该类为一个完整的 transformer 结构 """

    def __init__(self, config):
        super().__init__()

        hidden_size = config.hidden_size
        inner_dim = 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.att = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, hidden_states, attention_mask=None):
        # ------------------------------------------------------
        # multi-head attention part
        # ------------------------------------------------------
        # save the input of the attention sub-layer for the residual connection
        residual = hidden_states
        # pre-layer norm (Pre-LN)
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.att(hidden_states, attention_mask)
        # residual connection
        hidden_states = attn_outputs + residual

        # ------------------------------------------------------
        # FFN part
        # ------------------------------------------------------
        # save the input of the FFN sub-layer for the residual connection
        residual = hidden_states
        # pre-layer norm (Pre-LN)
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states

        return hidden_states


class GPT2Model(nn.Module):
    """ 该类是堆叠 embedding 层以及多个 transformer 层 """

    def __init__(self, config):
        super().__init__()

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)

        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        position_ids = torch.arange(input_shape[-1], device=input_ids.device).unsqueeze(0)

        # Embedding layer: only the token embedding and position embedding are summed here; the token type embedding is ignored
        inputs_embed = self.wte(input_ids)
        position_embed = self.wpe(position_ids)
        # inputs_embed is [bs, seq_len, embed_dim], position_embed is [1, seq_len, embed_dim]; the addition broadcasts
        hidden_states = inputs_embed + position_embed
        hidden_states = self.drop(hidden_states)

        # pass through the stack of transformer layers
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask)

        # since the transformer layers use pre-layer norm, a final layer norm is applied at the end
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
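
To run the whole stack end to end, here is a minimal usage sketch. It assumes the hypothetical SimpleGPT2Config defined earlier (shrunk for speed), and build_causal_mask is a helper introduced here, not part of transformers; the mask is built as an additive tensor because this simplified GPT2Attention adds attention_mask directly to the attention scores:

def build_causal_mask(seq_len):
    # additive causal mask: 0 on and below the diagonal, -inf above it;
    # shaped [1, 1, seq_len, seq_len] so it broadcasts over batch and heads
    mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
    return mask.view(1, 1, seq_len, seq_len)


config = SimpleGPT2Config(hidden_size=64, num_hidden_layers=2, num_attention_heads=4)
model = GPT2Model(config)

input_ids = torch.randint(0, config.vocab_size, (2, 10))  # [bs=2, seq_len=10]
attention_mask = build_causal_mask(input_ids.size(1))

hidden_states = model(input_ids, attention_mask)
print(hidden_states.shape)                                # torch.Size([2, 10, 64])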

Reference#