LSTM原理#
图1 |
1、输入和输出#
先把整个LSTMCell当成一个黑盒,只看其输出和输入。
输出:
- h_t:表示当前t时刻的细胞的输出;如果使用LSTM做nlp中的序列标注任务,每个token的输出就是该值;
- C_t:表示当前t时刻的细胞状态,字母 C 就是细胞cell的首字母,LSTM能够将之前时刻的信息传递到后面的时刻,就是使用该值做的信息传递;
输入:
- h_{t-1}:表示上一时刻(t-1时刻)的细胞的输出;
- C_{t-1}:表示上一时刻(t-1时刻)的细胞状态;
- x_t:表示当前t时刻的输入;如果使用LSTM做nlp中的序列标注任务,每个token经过embedding之后就是这里的输入;
2、图中的两种颜色#
之前在学习LSTM的原理时,看过的所有LSTM资料,基本都会提到遗忘门、输入门、输出门这几个门控单元。这些门控单元的作用是控制多少信息通过,它们会输出一个0到1之间的概率,将这个概率乘到相应的信息路径上,就能控制该路径上的信息能够通过多少。如果门控单元输出的概率是1.0,就表示允许所有的信息通过;如果输出的是0.6,就表示仅允许60%的信息通过。
这样在LSTM的细胞中就有些路径是传递信息的,而另外一些则是门控单元。当时最费解的就是:哪些路径是用来传递信息的?哪些路径属于门控单元?在图1中使用颜色区分这两者:绿色的线表示用来传递信息的,橘色的线表示各种门控单元。比如 f_t 表示遗忘门, i_t 表示输入门, o_t 表示输出门。每一条橘色的线都会输出一个0到1之间的概率,将该概率乘到其对应的绿色的路径上,就能实现控制绿色路径上的信息通过多少的目的。
3、公式详情#
关于LSTM细胞内部的结构,依照图1,下面从右往左逐步分析。需要说明一下,图1只是简图,并不是所有的运算都在图上有体现,比如各种激活函数在图上就没有体现。
LSTM细胞中各部分的详细计算公式说明如下:
-
当前t时刻细胞的输出 h_t ,它是由当前t时刻细胞的状态乘上遗忘门得到的,公式如下:
\begin{equation}h_t = o_t * \text{tanh}(C_t)\end{equation} -
当前t时刻的细胞状态 C_t 是由两部分组成的:前一时刻(t-1时刻)的细胞状态、当前t时刻的输入。公式如下:
\begin{equation}C_t = f_t * C_{t-1} + i_t * \tilde{C}_t\end{equation}在上述公式中:
-
C_{t-1} 表示前一时刻的细胞状态;f_t 是遗忘门,用于控制前一时刻的状态有多大比例能够通过;
-
\tilde{C}_t 表示仅由当前t时刻的输入计算出来的(该值中仅含有当前t时刻的输入信息,不包含任何之前时刻的历史信息);i_t 是输入门,用于控制当前t时刻的输入信息有多大比例能够通过;
-
-
前一时刻(t-1时刻)的细胞状态 C_{t-1} 是由前一时刻传递过来的。
-
当前t时刻的输入 \tilde{C}_t 是根据输入信息计算出来的,公式如下:
\begin{equation}\tilde{C}_t=\text{tanh}(W_c \cdot [h_{t-1}, x_t] + b_c)\end{equation}在上述公式中:
-
[h_{t-1}, x_t] 表示将向量 h_{t-1} 和 x_t 拼接起来;
-
W_c \cdot [h_{t-1}, x_t] + b_c 表示对向量 [h_{t-1}, x_t] 做一个线性变换;在模型上来说就是经过一个linear layer,其中 W_c、b_c 是可学习的参数;
-
-
三个门(遗忘门f_t、输入门i_t、输出门o_t)所对应的概率的计算公式都是相似的,如下所示:
\begin{equation}f_t = \text{sigmoid}(W_f \cdot [h_{t-1}, x_t] + b_f)\end{equation}\begin{equation}i_t = \text{sigmoid}(W_i \cdot [h_{t-1}, x_t] + b_i)\end{equation}\begin{equation}o_t = \text{sigmoid}(W_o \cdot [h_{t-1}, x_t] + b_o)\end{equation}在上述公式中:
-
[h_{t-1}, x_t] 表示将向量 h_{t-1} 和 x_t 拼接起来;
-
W \cdot [h_{t-1}, x_t] + b 表示对向量 [h_{t-1}, x_t] 做一个线性变换;在模型上来说就是经过一个linear layer,其中 W_f、W_i、W_o、b_f、b_i、b_o 都是可学习的参数;
-
\text{sigmoid}(\cdot) 是激活函数,目的是将其范围压缩到 [0,1] 之间;
-
以上就是一个LSTM细胞内所有部分的计算公式。
4、实现代码#
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMCell(nn.Module):
def __init__(self):
super().__init__()
self.linear_f = nn.Linear() # 遗忘门
self.linear_i = nn.Linear() # 输入门
self.linear_c = nn.Linear() # 当前输入对应的准内部状态
self.linear_o = nn.Linear() # 输出门
def forward(self, inputs, h_t_1, c_t_1):
f_t = F.sigmoid(self.linear_f(torch.cat([inputs, h_t_1], dim=-1)))
i_t = F.sigmoid(self.linear_i(torch.cat([inputs, h_t_1], dim=-1)))
_c_t = F.tanh(self.linear_c(torch.cat([inputs, h_t_1], dim=-1)))
c_t = f_t * c_t_1 + i_t * _c_t
o_t = F.sigmoid(self.linear_o(torch.cat([inputs, h_t_1], dim=-1)))
h_t = o_t * F.tanh(c_t)
return h_t, c_t
class LSTM(nn.Module):
def __init__(self):
super().__init__()