[隐藏左侧目录栏][显示左侧目录栏]

LSTM原理#

图1

1、输入和输出#

先把整个LSTMCell当成一个黑盒,只看其输出和输入。

输出:

  • h_t:表示当前t时刻的细胞的输出;如果使用LSTM做nlp中的序列标注任务,每个token的输出就是该值;
  • C_t:表示当前t时刻的细胞状态,字母 C 就是细胞cell的首字母,LSTM能够将之前时刻的信息传递到后面的时刻,就是使用该值做的信息传递;

输入:

  • h_{t-1}:表示上一时刻(t-1时刻)的细胞的输出;
  • C_{t-1}:表示上一时刻(t-1时刻)的细胞状态;
  • x_t:表示当前t时刻的输入;如果使用LSTM做nlp中的序列标注任务,每个token经过embedding之后就是这里的输入;

2、图中的两种颜色#

之前在学习LSTM的原理时,看过的所有LSTM资料,基本都会提到遗忘门、输入门、输出门这几个门控单元。这些门控单元的作用是控制多少信息通过,它们会输出一个0到1之间的概率,将这个概率乘到相应的信息路径上,就能控制该路径上的信息能够通过多少。如果门控单元输出的概率是1.0,就表示允许所有的信息通过;如果输出的是0.6,就表示仅允许60%的信息通过。

这样在LSTM的细胞中就有些路径是传递信息的,而另外一些则是门控单元。当时最费解的就是:哪些路径是用来传递信息的?哪些路径属于门控单元?在图1中使用颜色区分这两者:绿色的线表示用来传递信息的,橘色的线表示各种门控单元。比如 f_t 表示遗忘门, i_t 表示输入门, o_t 表示输出门。每一条橘色的线都会输出一个0到1之间的概率,将该概率乘到其对应的绿色的路径上,就能实现控制绿色路径上的信息通过多少的目的。

3、公式详情#

关于LSTM细胞内部的结构,依照图1,下面从右往左逐步分析。需要说明一下,图1只是简图,并不是所有的运算都在图上有体现,比如各种激活函数在图上就没有体现。

LSTM细胞中各部分的详细计算公式说明如下:

  • 当前t时刻细胞的输出 h_t ,它是由当前t时刻细胞的状态乘上遗忘门得到的,公式如下:

    \begin{equation}h_t = o_t * \text{tanh}(C_t)\end{equation}
  • 当前t时刻的细胞状态 C_t 是由两部分组成的:前一时刻(t-1时刻)的细胞状态、当前t时刻的输入。公式如下:

    \begin{equation}C_t = f_t * C_{t-1} + i_t * \tilde{C}_t\end{equation}

    在上述公式中:

    • C_{t-1} 表示前一时刻的细胞状态;f_t 是遗忘门,用于控制前一时刻的状态有多大比例能够通过;

    • \tilde{C}_t 表示仅由当前t时刻的输入计算出来的(该值中仅含有当前t时刻的输入信息,不包含任何之前时刻的历史信息);i_t 是输入门,用于控制当前t时刻的输入信息有多大比例能够通过;

  • 前一时刻(t-1时刻)的细胞状态 C_{t-1} 是由前一时刻传递过来的。

  • 当前t时刻的输入 \tilde{C}_t 是根据输入信息计算出来的,公式如下:

    \begin{equation}\tilde{C}_t=\text{tanh}(W_c \cdot [h_{t-1}, x_t] + b_c)\end{equation}

    在上述公式中:

    • [h_{t-1}, x_t] 表示将向量 h_{t-1}x_t 拼接起来;

    • W_c \cdot [h_{t-1}, x_t] + b_c 表示对向量 [h_{t-1}, x_t] 做一个线性变换;在模型上来说就是经过一个linear layer,其中 W_cb_c 是可学习的参数;

  • 三个门(遗忘门f_t、输入门i_t、输出门o_t)所对应的概率的计算公式都是相似的,如下所示:

    \begin{equation}f_t = \text{sigmoid}(W_f \cdot [h_{t-1}, x_t] + b_f)\end{equation}
    \begin{equation}i_t = \text{sigmoid}(W_i \cdot [h_{t-1}, x_t] + b_i)\end{equation}
    \begin{equation}o_t = \text{sigmoid}(W_o \cdot [h_{t-1}, x_t] + b_o)\end{equation}

    在上述公式中:

    • [h_{t-1}, x_t] 表示将向量 h_{t-1}x_t 拼接起来;

    • W \cdot [h_{t-1}, x_t] + b 表示对向量 [h_{t-1}, x_t] 做一个线性变换;在模型上来说就是经过一个linear layer,其中 W_fW_iW_ob_fb_ib_o 都是可学习的参数;

    • \text{sigmoid}(\cdot) 是激活函数,目的是将其范围压缩到 [0,1] 之间;

以上就是一个LSTM细胞内所有部分的计算公式。

4、实现代码#

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMCell(nn.Module):

    def __init__(self):
        super().__init__()

        self.linear_f = nn.Linear()  # 遗忘门
        self.linear_i = nn.Linear()  # 输入门
        self.linear_c = nn.Linear()  # 当前输入对应的准内部状态
        self.linear_o = nn.Linear()  # 输出门

    def forward(self, inputs, h_t_1, c_t_1):
        f_t = F.sigmoid(self.linear_f(torch.cat([inputs, h_t_1], dim=-1)))
        i_t = F.sigmoid(self.linear_i(torch.cat([inputs, h_t_1], dim=-1)))
        _c_t = F.tanh(self.linear_c(torch.cat([inputs, h_t_1], dim=-1)))

        c_t = f_t * c_t_1 + i_t * _c_t

        o_t = F.sigmoid(self.linear_o(torch.cat([inputs, h_t_1], dim=-1)))
        h_t = o_t * F.tanh(c_t)

        return h_t, c_t

class LSTM(nn.Module):

    def __init__(self):
        super().__init__()