GPT2 Model Structure and Implementation Code
The GPT2 model structure makes only small changes relative to GPT, and later models such as GPT3 were in turn built on top of GPT2, so it is well worth understanding the details of the GPT2 architecture.
Hugging Face's transformers project contains a complete implementation of GPT2, but it is a large engineering codebase packed with code for many different features, which makes it hard to read. The code below extracts only the model-structure part from the GPT2 implementation in transformers. Note that this code is intended for learning and understanding; it is not necessarily the most elegant implementation, but rather the one chosen to be easiest to follow.
The main formulas involved in the model are as follows:
\begin{equation}\text{Attention}(x) = \text{softmax}\left(\frac{(xW_q)(xW_k)^\top}{\sqrt{d_k}}\right) \cdot xW_v\end{equation}
\begin{equation}\text{FFN}(x) = \text{act}(xW_1)W_2\end{equation}
\begin{equation}\text{transformer}(x) = \text{residual}(\text{FFN}(\text{LN}(\text{residual}(\text{Attention}(\text{LN}(x))))))\end{equation}
In the formulas above, \text{act}(\cdot) denotes the activation function.
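Because GPT2 is a decoder-only model, the attention is causally masked in practice. In the code below the mask appears as an additive term M: 0 at positions a token is allowed to attend to and a very large negative value at future positions, so the softmax drives those weights to (almost) zero:

\begin{equation}\text{Attention}(x) = \text{softmax}\left(\frac{(xW_q)(xW_k)^\top}{\sqrt{d_k}} + M\right) \cdot xW_v\end{equation}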
The code below contains four classes:

- `GPT2Attention`: implements the multi-head attention part;
- `GPT2MLP`: implements the FFN (feed-forward network) part;
- `GPT2Block`: stacks multi-head attention and FFN, i.e. one transformer layer;
- `GPT2Model`: stacks the embedding layer and multiple transformer layers, i.e. a complete decoder-only model structure.
The main parts of the code are already commented and should be fairly easy to follow:
```python
import torch
from torch import nn


class GPT2Attention(nn.Module):
    """ A complete multi-head attention module """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        self.c_attn = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: [bs, seq_len, embed_dim]
        # W_q, W_k and W_v are fused into a single matrix multiplication (c_attn),
        # and the result is split back into query, key and value afterwards
        # query: [bs, seq_len, embed_dim]
        # key:   [bs, seq_len, embed_dim]
        # value: [bs, seq_len, embed_dim]
        query, key, value = self.c_attn(hidden_states).split(self.embed_dim, dim=2)

        # [bs, seq_len, embed_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, num_heads, seq_len, head_dim]
        new_shape = query.size()[:-1] + (self.num_heads, self.head_dim)
        query, key, value = query.view(new_shape), key.view(new_shape), value.view(new_shape)
        query, key, value = query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)

        # attention weight matrix att_weights: [bs, num_heads, seq_len, seq_len]
        att_weights = torch.matmul(query, key.transpose(-1, -2))
        # divide the attention scores by sqrt(d_k), then apply softmax
        att_weights = att_weights / (value.size(-1) ** 0.5)
        if attention_mask is not None:
            # attention_mask is additive: 0 for visible positions, a large negative value for masked ones
            att_weights = att_weights + attention_mask
        att_weights = nn.functional.softmax(att_weights, dim=-1)
        att_weights = self.attn_dropout(att_weights)

        # att_output: [bs, num_heads, seq_len, head_dim]
        #           = [bs, num_heads, seq_len, seq_len] x [bs, num_heads, seq_len, head_dim]
        att_output = torch.matmul(att_weights, value)

        # Note: att_output cannot be reshaped directly from [bs, num_heads, seq_len, head_dim]
        # to [bs, seq_len, embed_dim]; the head and sequence dimensions must be swapped first.
        # [bs, num_heads, seq_len, head_dim] -> [bs, seq_len, num_heads, head_dim] -> [bs, seq_len, embed_dim]
        att_output = att_output.permute(0, 2, 1, 3).contiguous()
        att_output = att_output.view(att_output.size()[:-2] + (self.num_heads * self.head_dim,))

        # output projection followed by dropout
        att_output = self.c_proj(att_output)
        att_output = self.resid_dropout(att_output)
        return att_output
```
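The `attention_mask` argument above is expected to be an additive mask that broadcasts to `[bs, num_heads, seq_len, seq_len]`. Below is a minimal sketch of how such a causal mask could be built; the shapes and the choice of fill value are illustrative assumptions (the real transformers code prepares the mask elsewhere):

```python
import torch

seq_len = 5

# lower-triangular matrix: True where a query position may attend to a key position
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# additive mask: 0 where attention is allowed, a very large negative value where it is not,
# so that softmax pushes the masked positions to (almost) zero probability
attention_mask = torch.zeros(seq_len, seq_len).masked_fill(~causal, torch.finfo(torch.float32).min)

# add leading dimensions so it broadcasts against [bs, num_heads, seq_len, seq_len]
attention_mask = attention_mask[None, None, :, :]
```

Passing a tensor of this form as `attention_mask` lets each token attend only to itself and the tokens before it.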
```python
class GPT2MLP(nn.Module):
    """ A complete FFN (feed-forward network) module """

    def __init__(self, intermediate_size, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        # usually intermediate_size = 4 * self.embed_dim, so this is a
        # project-up then project-down structure
        self.c_fc = nn.Linear(self.embed_dim, intermediate_size)
        self.c_proj = nn.Linear(intermediate_size, self.embed_dim)
        # GPT2 uses a GELU-style activation; nn.GELU() is used here as a simple stand-in
        # (ReLU or another activation would also work for illustration purposes)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states
```
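A quick sanity check of the project-up / project-down shape behaviour, using the `GPT2MLP` class defined above and a hypothetical minimal config (`SimpleNamespace` only stands in for the real transformers config; the numbers are arbitrary):

```python
from types import SimpleNamespace

import torch

config = SimpleNamespace(hidden_size=768, resid_pdrop=0.1)
mlp = GPT2MLP(intermediate_size=4 * config.hidden_size, config=config)

x = torch.randn(2, 10, config.hidden_size)  # [bs, seq_len, embed_dim]
print(mlp(x).shape)                         # torch.Size([2, 10, 768]) -- shape is preserved
# internally: 768 -> 3072 (c_fc) -> 3072 (act) -> 768 (c_proj)
```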
```python
class GPT2Block(nn.Module):
    """ A complete transformer layer """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = 4 * hidden_size
        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.att = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(inner_dim, config)

    def forward(self, hidden_states, attention_mask=None):
        # ------------------------------------------------------
        # multi-head attention part
        # ------------------------------------------------------
        # save the input of the multi-head attention layer for the residual connection
        residual = hidden_states
        # pre layer norm
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.att(hidden_states, attention_mask)
        # residual connection
        hidden_states = attn_outputs + residual

        # ------------------------------------------------------
        # FFN part
        # ------------------------------------------------------
        # save the input of the FFN layer for the residual connection
        residual = hidden_states
        # pre layer norm
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states
```
```python
class GPT2Model(nn.Module):
    """ Stacks the embedding layer and multiple transformer layers """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        position_ids = torch.arange(input_shape[-1], device=input_ids.device).unsqueeze(0)

        # embedding layer: only the token embedding and position embedding are summed here;
        # the token type embedding is ignored
        inputs_embed = self.wte(input_ids)
        position_embed = self.wpe(position_ids)
        # inputs_embed is [bs, seq_len, embed_dim] and position_embed is [1, seq_len, embed_dim],
        # so the addition broadcasts over the batch dimension
        hidden_states = inputs_embed + position_embed
        hidden_states = self.drop(hidden_states)

        # pass through the stacked transformer layers
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask)

        # since the transformer blocks use pre layer norm, a final layer norm is applied at the end
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
```
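To tie everything together, the snippet below is a usage sketch: it builds a hypothetical minimal config (the field values roughly follow GPT2-small, with a smaller hidden size and fewer layers to keep the example light), constructs the additive causal mask as shown earlier, and runs a forward pass through `GPT2Model`:

```python
from types import SimpleNamespace

import torch

# hypothetical minimal config; only the fields read by the classes above are provided
config = SimpleNamespace(
    vocab_size=50257,
    max_position_embeddings=1024,
    hidden_size=256,
    num_attention_heads=4,
    num_hidden_layers=2,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    embd_pdrop=0.1,
    layer_norm_epsilon=1e-5,
)

model = GPT2Model(config)

bs, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (bs, seq_len))

# additive causal mask, broadcastable to [bs, num_heads, seq_len, seq_len]
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
attention_mask = torch.zeros(seq_len, seq_len).masked_fill(~causal, torch.finfo(torch.float32).min)
attention_mask = attention_mask[None, None, :, :]

hidden_states = model(input_ids, attention_mask)
print(hidden_states.shape)  # torch.Size([2, 16, 256]) -> [bs, seq_len, hidden_size]
```

Note that `GPT2Model` returns hidden states, not logits; the language-model head that maps hidden states back to the vocabulary is not part of this class.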