初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

LSTMはなぜRNNより勾配消失を起こしにくいのか?:誤差逆伝播の更新式を比較して確認

はじめに

 株価の変動や文章のような系列データを扱う深層学習モデルに、RNN(Recurrent Neural Network)とLSTM(Long Short Term Memory)があります。RNNは、系列データを扱う基本的なモデルの一種ですが、学習の際に勾配消失が発生しやすく、層を多く重ねられないという課題があります。LSTMはその課題を緩和し、層を多く重ねることを可能にしたモデルです。これにより、系列データに関する予測問題などの精度を向上させることができました。

 「なぜLSTMは勾配消失を起こしにくいのか?」この疑問に数式を用いて解説している日本語の記事は、私が調べた限りそれほど多くないようです。そこで、本記事ではRNNとLSTMの誤差逆伝播の更新式を比較し、LSTMが勾配消失を起こしにくい理由を明確にします。LSTMが勾配消失を起こしにくい理由を深く理解することで、今後提案される新しい深層学習モデルの勾配消失についてもよりよく理解できると私は考えています。

 本記事の構成は以下の通りです。まず、系列データの定義と例を簡単に紹介します。その後、RNNとLSTMそれぞれの構造を紹介し、それらの誤差逆伝播の更新式を導出します。そして最後に、両者の更新式を比較し、LSTMがRNNよりも勾配消失を起こしにくい理由を説明します。

系列データ

 この章では、系列データを簡単に紹介します。深層学習では、系列データを以下のように定義しています。

個々の要素が順序付きの集まり


\boldsymbol{x}^1, \boldsymbol{x}^2, \boldsymbol{x}^3, \dots, \boldsymbol{x}^T
として与えられるデータ

 系列データの例としては、株価の値動きや文章などが考えられます。このようなデータでは、各要素の順序が重要となります。ここで、「 t」という記号は系列の順番を表し、時系列データでなくても「時刻」と呼ばれることがあります。この記事でも、 tを時刻と呼びます。

 以下に系列データに関する問題例を2つ挙げます。

1.株価予測
 一つ目の例は株価予測です。例えば、2024/7/29から8/2までの日経平均株価始値を使い、8/3の始値を予測する問題を考えます。このとき、以下が説明変数です。

時刻( t 始値 \boldsymbol{x} ^ t
1 38,139.12
2 38,241.35
3 38,140.77
4 38,781.56
5 37,444.17

です。 t=1が7/29を、 t=5が8/2を表しています。また、目的変数は8/3の始値です。

 この問題は、複数の時刻の説明変数を使い、一つの目的変数を予測する問題です(図1)。

図1:株価予測の説明変数と目的変数

2.品詞分類
 二つ目の例は文章の品詞分類です。例えば、"I have a pen"という文章の各単語について、その品詞を分類する問題を考えます。この場合、以下が説明変数となります。

時刻( t 単語( \boldsymbol{x} ^ t
1 I
2 have
3 a
4 pen

です。また、目的変数は以下の通りです。

時刻( t 品詞( \boldsymbol{d} ^ t
1 名詞
2 動詞
3 冠詞
4 名詞

 この問題は、複数の時刻の説明変数を使い、各時刻ごとに目的変数を予測する問題です(図2)。

図2:品詞分類の説明変数と目的変数

 系列データを使った問題の例を2つ挙げました。これらの問題はいずれも、目的変数がその時刻の説明変数だけでなく、前の時刻の説明変数にも依存しています。そのため、通常の深層学習モデルでは十分に対応できません。

 このような要件に対応するため、RNNとLSTMが開発されました。

RNNとLSTM

RNN

RNNの構造

 まずは、系列データを扱う基本的なモデルRNNを紹介します。

 RNNは、以下のようにデータを処理します

  • 各時刻の説明変数 \boldsymbol{x}^tから内部状態 \boldsymbol{u}^tを計算します
  • 内部状態 \boldsymbol{u}^tに活性化関数を適用し、 \boldsymbol{z}^tを計算します
  • 内部状態 \boldsymbol{z}^tから出力 \boldsymbol{v}^tを計算します
  • 出力 \boldsymbol{v}^tから目的変数 \boldsymbol{y}^tを計算します

 特徴的なのは、RNNが前の時刻の内部状態を次の時刻に引き継ぐことです。この仕組みにより、 p(\boldsymbol{y}|\boldsymbol{x}^t, \boldsymbol{x}^{t-1}, \dots, \boldsymbol{x}^1) のように、現在の出力が過去の入力にも依存するモデルを構築できます。

 以下の図3は、RNNの構造を視覚的に表したものです。

図3:RNNの構造

 次に、順伝播を説明するために、各変数やパラメータを以下のように定義します。

  • 各変数の定義
    • 入力: \boldsymbol{x} ^ {t} = (x _ i ^ t)
    • 内部状態:
      •  \boldsymbol{u} ^ {t} = (u _ j ^ t)
      •  \boldsymbol{z} ^ {t} = (z _ j ^ t)
    • 出力層への入力: \boldsymbol{v} ^ {t} = (v _ k ^ t)
    • 出力: \boldsymbol{y} ^ {t} = (y _ k ^ t)
  • 各パラメータの定義
    • 入力 \boldsymbol{x} ^ {t}と内部状態 \boldsymbol{u} ^ {t}間のパラメータ: W ^ {\rm in} = (w _ {ji} ^ {\rm in})
    • 内部状態 \boldsymbol{z} ^ {t-1} \boldsymbol{u} ^ {t}間のパラメータ: W = (w _ {jj ^ \prime})
    • 内部状態 \boldsymbol{z} ^ {t}と出力層 \boldsymbol{v} ^ {t}間のパラメータ: W ^ {\rm out} = (w _ {kj} ^ {\rm out})

 推定対象のパラメータは W, W ^ {\rm in}, W ^ {\rm out}です。

 このとき、順伝播の式を成分表記すると、


\begin{eqnarray}
u _ j ^ {t+1} &=& \sum_{j^\prime} w_{jj^\prime} z_{j^\prime}^{t} + \sum_i w_{ji}^{\rm in} x_i^{t+1} \tag{1} \\
z_j^{t+1} &=& f(u _ j ^ {t+1} )  \tag{2} \\
v_k^{t+1} &=& \sum_j w_{kj}^{\rm out} z_j^{t+1}  \tag{3} \\
y_k^{t+1} &=& f^{\rm out}(v_k^{t+1})  \tag{4} 
\end{eqnarray}

となります。ただし、 fは活性化関数です。双曲線正接関数 {\rm tanh}などが使われます。また、 f ^ {\rm out}は出力層の活性化関数です。問題設定に応じてソフトマックス関数などが使われます。

 (1)~(4)式を行列表記すると以下のようになります。


\begin{eqnarray}
\boldsymbol{u}^{t+1} &=& W \boldsymbol{z}^{t} + W^{\rm in}  \boldsymbol{x}^{t+1} \tag{5} \\
\boldsymbol{z}^{t+1} &=& \boldsymbol{f}(\boldsymbol{u}^{t+1}) \tag{6} \\
\boldsymbol{v}^{t+1} &=& W^{\rm out} \boldsymbol{z}^{t+1} \tag{7} \\
\boldsymbol{y}^{t+1} &=& \boldsymbol{f}^{\rm out} (\boldsymbol{v}^{t+1}) \tag{8} \\
\end{eqnarray}

 ただし、 \boldsymbol{a}=(a _ 1, \dots, a _ n)に対して \boldsymbol{f}(\boldsymbol{a}) = (f(a _ 1), \dots, f(a _ n)), \ \boldsymbol{f}^{\rm out}(\boldsymbol{a}) = (f^{\rm out}(a _ 1), \dots, f^{\rm out}(a _ n))です。

 RNNの動作ポイントは以下の通りです。

 (5)、(6)式より内部状態  \boldsymbol{z} ^ {t+1} は、前の時刻の内部状態  \boldsymbol{z}^{t} に依存しています。この仕組みにより、過去の説明変数  \boldsymbol{x}^1, \boldsymbol{x}^2, \dots, \boldsymbol{x}^t の情報が現在の出力  \boldsymbol{y}^t に反映されます。

 RNNはシンプルな構造を持つため、誤差逆伝播によるパラメータの更新式を、LSTMに比べれば簡単に導出できます。次の節では、誤差逆伝播の更新式を導出する上で必要な考え方を説明し、実際に導出します。

RNNの誤差逆伝播

 この節では、RNNの誤差逆伝播を用いたパラメータの更新式を導出します。ただし、損失関数  E に対する全てのパラメータの微分を求めると、計算が煩雑になりすぎてしまいます。そのため、ここでは計算の要となる  \dfrac{\partial E}{\partial \boldsymbol{u}^t} の導出に焦点を絞ります。

 まず、導出に必要な「ベクトルのベクトル微分」を以下のように定義します。

定義:ベクトルのベクトル微分
 \boldsymbol{z} = (z_1, \dots, z_I) \boldsymbol{u} = (u_1, \dots, u_J)をベクトルとする。このとき、ベクトルによるベクトルの微分 \dfrac{\partial \boldsymbol{z}}{\partial \boldsymbol{u}}を以下のように定義する。

\begin{eqnarray}
\dfrac{\partial \boldsymbol{z}}{\partial \boldsymbol{u}} = 
\left(
\begin{array}{cccc}
   \dfrac{\partial z_1}{\partial u_1} & \dfrac{\partial z_1}{\partial u_2} & \cdots &\dfrac{\partial z_1}{\partial u_J} \\
   \dfrac{\partial z_2}{\partial u_1} & \dfrac{\partial z_2}{\partial u_2} & \cdots & \dfrac{\partial z_2}{\partial u_J} \\
   \vdots & \vdots & \ddots & \vdots \\
   \dfrac{\partial z_I}{\partial u_1} & \dfrac{\partial z_I}{\partial u_2} & \cdots & \dfrac{\partial z_I}{\partial u_J} \\
\end{array}
\right)
\end{eqnarray}

 なお、導出の中で、ベクトルによるスカラー微分 \left( \dfrac{\partial E}{\partial \boldsymbol{u} ^ t} {\rm など}\right)を使うことがあります。この場合は、スカラーを成分が一つだけのベクトルとみなして、「定義:ベクトルのベクトル微分」を適用します。

 次に、導出に必要な合成関数の微分法についての命題を紹介します。なお、証明は「補足:誤差逆伝播の更新式導出に用いた命題の証明」で行います。

命題:合成関数の微分
  w _ {ij} Eスカラー \boldsymbol{u} ^ t \ (t=2, \dots, T)をベクトルとする。また、 w _ {ij} \to (\boldsymbol{u} ^ 2, \dots, \boldsymbol{u} ^ T) (\boldsymbol{u} ^ 2, \dots, \boldsymbol{u} ^ T) \to Eの二つの写像を考える。このとき、 w _ {ij}による E微分 \dfrac{\partial E}{\partial w _ {ij}}は以下のようになる。

\begin{eqnarray}
\dfrac{\partial E}{\partial w _ {ij}} = \displaystyle{\sum_{t = 2} ^ {T}} \dfrac{\partial E}{\partial \boldsymbol{u}^t} \dfrac{\partial \boldsymbol{u}^t}{\partial w_{ij}}
\end{eqnarray}

 これで準備が整いました。以下では、 \dfrac{\partial E}{\partial \boldsymbol{u} ^ t}を具体的に導出します。

 順伝播の式 (6) より、 \boldsymbol{u} ^ tから \boldsymbol{z} ^ tを経由して損失関数 Eが計算されます。つまり、写像 \boldsymbol{u} ^ t \to \boldsymbol{z} ^ t \to Eです。したがって、合成関数の微分法より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ t} = \dfrac{\partial E}{\partial \boldsymbol{z} ^ t} \dfrac{\partial \boldsymbol{z} ^ t}{\partial \boldsymbol{u} ^ t} 
\end{eqnarray}

となります。

 同様に、順伝播の(5)と(7)式より、 \boldsymbol{z} ^ tから \boldsymbol{u} ^ {t+1} \boldsymbol{v} ^ {t}を経由して損失関数 Eが計算されます。したがって、合成関数の微分法より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{z} ^ t} = \dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} \dfrac{\partial \boldsymbol{u} ^ {t+1}}{\partial \boldsymbol{z} ^ t} + \dfrac{\partial E}{\partial \boldsymbol{v} ^ t} \dfrac{\partial \boldsymbol{v} ^ {t}}{\partial \boldsymbol{z} ^ t} 
\end{eqnarray}

となります。以上より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ t} = \left(\dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} \dfrac{\partial \boldsymbol{u} ^ {t+1}}{\partial \boldsymbol{z} ^ t} + \dfrac{\partial E}{\partial \boldsymbol{v} ^ t} \dfrac{\partial \boldsymbol{v} ^ {t}}{\partial \boldsymbol{z} ^ t} 
\right) \dfrac{\partial \boldsymbol{z} ^ t}{\partial \boldsymbol{u} ^ t} \tag{9}
\end{eqnarray}

となります。

 次に、RNNの更新式を、後述のLSTMの更新式と比較しやすいように、(9)式を成分表記します。

 まず、 \dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} \dfrac{\partial \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z} ^ {t}}については、「定義:ベクトルのベクトル微分」と(1)式より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} &=& \left( \dfrac{\partial E}{\partial u_1^{t+1}}, \dots, \dfrac{\partial E}{\partial u_j^{t+1}}, \dots, \dfrac{\partial E}{\partial u_J^{t+1}} \right) \\
\dfrac{\partial \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z} ^ {t}} &=&
\left(
\begin{array}{cccc}
\dfrac{\partial u_1^{t+1}}{\partial z_1^t} & \dfrac{\partial u_1^{t+1}}{\partial z_2^t} & \cdots & \dfrac{\partial u_1^{t+1}}{\partial z_J^t} \\
\dfrac{\partial u_2^{t+1}}{\partial z_1^t} & \dfrac{\partial u_2^{t+1}}{\partial z_2^t} & \cdots & \dfrac{\partial u_2^{t+1}}{\partial z_J^t} \\
\vdots & \vdots & \ddots & \vdots \\
\dfrac{\partial u_J^{t+1}}{\partial z_1^t} & \dfrac{\partial u_J^{t+1}}{\partial z_2^t} & \cdots & \dfrac{\partial u_J^{t+1}}{\partial z_J^t} \\
\end{array} \right)
&=& W
\end{eqnarray}

となります。したがって、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} \dfrac{\partial \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z} ^ {t}} &=& \left( \dfrac{\partial E}{\partial u_1^{t+1}}, \dots, \dfrac{\partial E}{\partial u_j^{t+1}}, \dots, \dfrac{\partial E}{\partial u_J^{t+1}} \right)
\left( \begin{array}{cccc}
w_{11} & w_{12} & \cdots & w_{1J} \\
w_{21} & w_{22} & \cdots & w_{2J} \\
\vdots & \vdots & \ddots & \vdots \\
w_{J1} & w_{J2} & \cdots & w_{JJ} \\
\end{array} \right) \\
&=& \left(
\displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j 1}, \dots, \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j j^\prime}, \dots, \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j J}
\right) \tag{10}
\end{eqnarray}

となります。さらに、


\begin{eqnarray}
\boldsymbol{d} = (d_1, \dots, d_{j^\prime}, \dots, d_J)= \dfrac{\partial E}{\partial \boldsymbol{v} ^ t} \dfrac{\partial \boldsymbol{v} ^ {t}}{\partial \boldsymbol{z} ^ t} \tag{11}
\end{eqnarray}

と定義します。また、 \boldsymbol{f}^\prime(\boldsymbol{u} ^ t) = (f^\prime(u _ 1 ^ t), \dots, f^\prime(u _ J ^ t))とすると(2)式より、


\begin{eqnarray}
\dfrac{\partial \boldsymbol{z} ^ t}{\partial  \boldsymbol{u} ^ t} &=& 
\left(
\begin{array}{cccc}
   \dfrac{\partial z_1^t}{\partial u_1^t} & \dfrac{\partial z_1^t}{\partial u_2^t} & \cdots &\dfrac{\partial z_1^t}{\partial u_J^t} \\
   \dfrac{\partial z_2^t}{\partial u_1^t} & \dfrac{\partial z_2^t}{\partial u_2^t} & \cdots & \dfrac{\partial z_2^t}{\partial u_J^t} \\
   \vdots & \vdots & \ddots & \vdots \\
   \dfrac{\partial z_I^t}{\partial u_1^t} & \dfrac{\partial z_I^t}{\partial u_2^t} & \cdots & \dfrac{\partial z_I^t}{\partial u_J^t} \\
\end{array}
\right) \\
&=& 
\left(
\begin{array}{cccc}
   f^\prime(u_1^t) & 0 & \cdots &0 \\
   0 & f^\prime(u_2^t) & \cdots & 0 \\
   \vdots & \vdots & \ddots & \vdots \\
   0 & 0 & \cdots & f^\prime(u_J^t) \\
\end{array}
\right) \tag{12}
\end{eqnarray}

です。以上より、(9)~(12)式を用いると、 \dfrac{\partial E}{\partial \boldsymbol{u} ^ t}の第 j ^ \prime成分 \dfrac{\partial E}{\partial u _ {j ^ \prime} ^ {t}}は、


\begin{eqnarray}
\dfrac{\partial E}{\partial u _ {j^\prime} ^ {t}} = \left( \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j j^\prime} + d_{j^\prime}\right) f^\prime(u_{j^\prime}^t) \tag{13}
\end{eqnarray}

となります。

 以上で、RNNの誤差逆伝播の更新式を導出できました。ここで導出した(13)式は、「RNNとLSTMの更新式の比較」の章でLSTMの更新式と比較します。

LSTM

LSTMの構造

 次に、RNNの勾配消失を抑えたモデルLSTMを紹介します。

 LSTMはRNNに「入力ゲート」、「忘却ゲート」、「出力ゲート」と呼ばれる3つのゲートを加えたモデルです。図4は、LSTMの構造を視覚的に表したものです(図を単純にするため、パラメータは省略しています)。

図4:LSTMの構造

 RNNと同様に順伝播の式を説明します。RNNの順伝播の式を説明する際に定義した各変数やパラメータに加え、LSTMで新しく使われる各変数とパラメータを以下のように定義します。

  • 各変数の定義
    • 内部状態: \boldsymbol{s} ^ {t} = (s _ j ^ t)
    • 入力ゲート: \boldsymbol{g} ^ {I, t} = (g _ j ^ {I, t})
    • 忘却ゲート: \boldsymbol{g} ^ {F, t} = (g _ j ^ {F, t})
    • 出力ゲート: \boldsymbol{g} ^ {O, t} = (g _ j ^ {O, t})
  • 各パラメータの定義
    • 入力 \boldsymbol{x} ^ {t}と入力ゲート \boldsymbol{g} ^ {I, t}間のパラメータ: W ^ {I, {\rm in}} = (w _ {ji} ^ {I, {\rm in}})
    • 内部状態 \boldsymbol{z} ^ {t-1}と入力ゲート \boldsymbol{g} ^ {I, t}間のパラメータ: W ^ I = (w _ {jj ^ \prime} ^ {I})
    • 内部状態 \boldsymbol{s} ^ {t-1}と入力ゲート \boldsymbol{g} ^ {I, t}間のパラメータ: \boldsymbol{w} ^ {I} = (w _  j ^ I)
    • 入力 \boldsymbol{x} ^ {t}と忘却ゲート \boldsymbol{g} ^ {F, t}間のパラメータ: W ^ {F, {\rm in}} = (w_{ji} ^ {F, {\rm in}})
    • 内部状態 \boldsymbol{z} ^ {t-1}と忘却ゲート \boldsymbol{g} ^ {F, t}間のパラメータ: W ^ F = (w _ {jj ^ \prime} ^ {F})
    • 内部状態 \boldsymbol{s} ^ {t-1}と忘却ゲート \boldsymbol{g} ^ {F, t}間のパラメータ: \boldsymbol{w} ^ {F} = (w _  j ^ F)
    • 入力 \boldsymbol{x} ^ {t}と出力ゲート \boldsymbol{g} ^ {O, t}間のパラメータ: W ^ {O, {\rm in}} = (w_{ji} ^ {O, {\rm in}})
    • 内部状態 \boldsymbol{z} ^ {t-1}と忘却ゲート \boldsymbol{g} ^ {O, t}間のパラメータ: W ^ O = (w _ {jj ^ \prime} ^ {O})
    • 内部状態 \boldsymbol{s} ^ {t}と忘却ゲート \boldsymbol{g} ^ {O, t}間のパラメータ: \boldsymbol{w} ^ {O} = (w _  j ^ O)

 推定対象のパラメータはRNNで用いた W, W ^ {\rm in}, W ^ {\rm out}に加え、 W ^ {I, {\rm in}}, W ^ I, \boldsymbol{w} ^ {I}, W ^ {F, {\rm in}}, W ^ F, \boldsymbol{w} ^ {F}, W ^ {O, {\rm in}}, W ^ O, \boldsymbol{w} ^ {O}です。

 このとき、順伝播の式を成分表記すると、


\begin{eqnarray}
u_j^t &=& \sum_i w_{ji}^{\rm in} x_i^{t} + \sum_{j^\prime} w_{jj^\prime} z_{j^\prime}^{t-1} \tag{14} \\
s_j^t &=& g_j^{F, t} s_j^{t-1} + g_j^{I, t} f(u_j^t) \tag{15} \\
z_j^t &=& g_j^{O, t} f(s_j^t) \tag{16} \\
v_k^t &=& \sum_j w_{kj}^{\rm out} z_j^t \tag{17} \\
y_k^t &=& f^{\rm out}(v_k^t) \tag{18} \\
g_j^{I, t} &=& \sigma\left(\sum_i w_{ji}^{I, {\rm in}} x_i^{t} + \sum_{j^\prime} w_{jj^\prime}^{I} z_{j^\prime}^{t-1} + w_j^I s_j^{t-1} \right) \tag{19} \\
g_j^{F, t} &=& \sigma\left(\sum_i w_{ji}^{F, {\rm in}} x_i^{t} + \sum_{j^\prime} w_{jj^\prime}^{F} z_{j^\prime}^{t-1} + w_j^F s_j^{t-1} \right) \tag{20} \\
g_j^{O, t} &=& \sigma\left(\sum_i w_{ji}^{O, {\rm in}} x_i^{t} + \sum_{j^\prime} w_{jj^\prime}^{O} z_{j^\prime}^{t-1} + w_j^O s_j^{t} \right) \tag{21} \\
\end{eqnarray}

となります。ただし、 \sigmaシグモイド関数です。

 (14)~(21)式を行列表記すると以下のようになります。


\begin{eqnarray}
\boldsymbol{u}^t &=& W^{\rm in} \boldsymbol{x}^t + W \boldsymbol{z}^{t-1} \tag{22} \\
\boldsymbol{s}^t &=& \boldsymbol{g}^{F, t} \odot \boldsymbol{s}^{t-1} + \boldsymbol{g}^{I, t} \odot f\left(\boldsymbol{u}^{t}\right) \tag{23} \\
\boldsymbol{z}^t &=& \boldsymbol{g}^{O, t} \odot f\left(\boldsymbol{s}^{t}\right) \tag{24} \\
\boldsymbol{v}^t &=& W^{\rm out} \boldsymbol{z}^t \tag{25} \\
\boldsymbol{y}^t &=& f^{\rm out}\left(\boldsymbol{v}^t\right) \tag{26} \\
\boldsymbol{g}^{I, t} &=& \boldsymbol{\sigma} \left( W^{I, {\rm in}}\boldsymbol{x}^t+ W^I \boldsymbol{z}^{t-1} + \boldsymbol{w}^I \odot \boldsymbol{s}^{t-1} \right) \tag{27} \\
\boldsymbol{g}^{F, t} &=& \boldsymbol{\sigma} \left( W^{F, {\rm in}}\boldsymbol{x}^t+ W^F \boldsymbol{z}^{t-1} + \boldsymbol{w}^F \odot \boldsymbol{s}^{t-1} \right) \tag{28} \\
\boldsymbol{g}^{O, t} &=& \boldsymbol{\sigma} \left( W^{O, {\rm in}}\boldsymbol{x}^t+ W^O \boldsymbol{z}^{t-1} + \boldsymbol{w}^O \odot \boldsymbol{s}^{t} \right) \tag{29} \\
\end{eqnarray}

 ただし、 \odotはベクトルの成分積です。つまり、 \boldsymbol{a} = (a _ 1, a _ 2, \dots, a _ n),\ \boldsymbol{b} = (b _ 1, b _ 2, \dots, b _ n)に対して \boldsymbol{a} \odot \boldsymbol{b} = (a _ 1 b _ 1, a _ 2 b _ 2, \dots, a _ nb _ n)です。また、 \boldsymbol{\sigma}(\boldsymbol{a}) = (\sigma(a _ 1), \sigma(a _ 2), \dots, \sigma(a _ n))です。

 LSTMはRNNに「入力ゲート」、「忘却ゲート」、「出力ゲート」を加えたモデルです。これにより、LSTMはRNNよりもデータの流れを細かく制御できます。そのため、LSTMはRNNより高い表現力を持っています。一方で、LSTMは推定対象のパラメータが多いので、学習に時間がかかるという問題もあります。

 次の節では、RNNと同様にLSTMの誤差逆伝播の更新式を導出します。

LSTMの誤差逆伝播

 この節では、LSTMの誤差逆伝播の更新式を導出します。RNNと同様に、 \dfrac{\partial E}{\partial \boldsymbol{u} ^ t}だけを導出します。

 順伝播の(23)式より、 \boldsymbol{u} ^ tから \boldsymbol{s} ^ tを経由して損失関数 Eが計算されます。つまり、写像 \boldsymbol{u} ^ t \to \boldsymbol{s} ^ t \to Eです。したがって、合成関数の微分法より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ t} = \dfrac{\partial E}{\partial \boldsymbol{s} ^ t} \dfrac{\partial \boldsymbol{s} ^ t}{\partial \boldsymbol{u} ^ t}  \tag{30}
\end{eqnarray}

となります。

 同様に、順伝播の(28)と(23)式より \boldsymbol{s} ^ tから \boldsymbol{s} ^ {t+1}を経由して、順伝播の(29)と(24)式より \boldsymbol{s} ^ tから \boldsymbol{z} ^ {t}を経由して損失関数 Eが計算されます。したがって、合成関数の微分法より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{s} ^ t} = \dfrac{\partial E}{\partial \boldsymbol{s} ^ {t+1}} \dfrac{\partial \boldsymbol{s} ^ {t+1}}{\partial \boldsymbol{s} ^ t} + \dfrac{\partial E}{\partial \boldsymbol{z} ^ t} \dfrac{\partial \boldsymbol{z} ^ {t}}{\partial \boldsymbol{s} ^ t} \tag{31}
\end{eqnarray}

となります。さらに順伝播の(22)式より \boldsymbol{z} ^ tから \boldsymbol{u} ^ {t+1}を経由して、(25)式より \boldsymbol{z} ^ tから \boldsymbol{v} ^ {t}を経由して、(27)~(29)式より \boldsymbol{z} ^ tから \boldsymbol{g} ^ {I, t+1} \boldsymbol{g} ^ {F, t+1} \boldsymbol{g} ^ {O, t+1}を経由して損失関数 Eが計算されます。したがって、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{z}^t} = \dfrac{\partial E}{\partial \boldsymbol{u}^{t+1}} \dfrac{\partial  \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z}^t} + \dfrac{\partial E}{\partial \boldsymbol{v}^t} \dfrac{\partial  \boldsymbol{v}^t}{\partial \boldsymbol{z}^t} + \displaystyle{\sum_{*= I, F, O}}\dfrac{\partial E}{\partial \boldsymbol{g}^{*, t+1}} \dfrac{\partial \boldsymbol{g}^{*, t+1}}{\partial \boldsymbol{z}^{t}} \tag{32}
\end{eqnarray}

となります。以上の(30)~(32)式より、


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ t} &=& \left(\dfrac{\partial E}{\partial \boldsymbol{s} ^ {t+1}} \dfrac{\partial  \boldsymbol{s}^{t+1}}{\partial \boldsymbol{s}^t} \right. \\
&+& \left. \left(\dfrac{\partial E}{\partial \boldsymbol{u}^{t+1}} \dfrac{\partial  \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z}^t} + \dfrac{\partial E}{\partial \boldsymbol{v}^t} \dfrac{\partial  \boldsymbol{v}^t}{\partial \boldsymbol{z}^t} + \displaystyle{\sum_{*= I, F, O}}\dfrac{\partial E}{\partial \boldsymbol{g}^{*, t+1}} \dfrac{\partial \boldsymbol{g}^{*, t+1}}{\partial \boldsymbol{z}^{t}} \right) \dfrac{\partial \boldsymbol{z}^t}{\partial \boldsymbol{s} ^ t}\right) \dfrac{\partial \boldsymbol{s} ^ t}{\partial \boldsymbol{u} ^ t} \tag{33}
\end{eqnarray}

となります。

 次に、前述のRNNの更新式と比較しやすいように、(33)式を成分表記します。

 まず、


\begin{eqnarray}
\boldsymbol{\alpha} &=& (\alpha_1, \dots, \alpha_J) = \dfrac{\partial E}{\partial \boldsymbol{s} ^ {t+1}} \dfrac{\partial  \boldsymbol{s}^{t+1}}{\partial \boldsymbol{s}^t} \tag{34} \\
\boldsymbol{\beta} &=& (\beta_1, \dots, \beta_J) = \dfrac{\partial E}{\partial \boldsymbol{v} ^ {t}} \dfrac{\partial  \boldsymbol{v}^{t}}{\partial \boldsymbol{z}^t} \tag{35}  \\
\boldsymbol{\gamma} &=& (\gamma_1, \dots, \gamma_J) = \displaystyle{\sum_{*= I, F, O}}\dfrac{\partial E}{\partial \boldsymbol{g}^{*, t+1}} \dfrac{\partial \boldsymbol{g}^{*, t+1}}{\partial \boldsymbol{z}^{t}}  \tag{36} \\
\boldsymbol{\eta} &=& \left(
\begin{array}{cccc}
\eta_{11} & \eta_{12} & \cdots & \eta_{1J} \\
\eta_{21} & \eta_{22} & \cdots & \eta_{2J} \\
\vdots & \vdots & \ddots & \vdots \\
\eta_{J1} & \eta_{J2} & \cdots & \eta_{JJ} \\
\end{array}
\right)
= \dfrac{\partial \boldsymbol{z}^t}{\partial \boldsymbol{s}^t} \tag{37}
\end{eqnarray}

と定義します。

 さらに、(10)式と同様に


\begin{eqnarray}
\dfrac{\partial E}{\partial \boldsymbol{u} ^ {t+1}} \dfrac{\partial \boldsymbol{u}^{t+1}}{\partial \boldsymbol{z} ^ {t}} &=& \left(
\displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j 1}, \dots, \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j j^\prime}, \dots, \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j J}
\right)  \tag{38} 
\end{eqnarray}

です。また、(12)式と同様に


\begin{eqnarray}
\dfrac{\partial \boldsymbol{s} ^ t}{\partial  \boldsymbol{u} ^ t} = \left(
\begin{array}{cccc}
   g_1^{I, t} f^\prime(u_1^t) & 0 & \cdots &0 \\
   0 & g_2^{I, t} f^\prime(u_2^t) & \cdots & 0 \\
   \vdots & \vdots & \ddots & \vdots \\
   0 & 0 & \cdots & g_J^{I, t} f^\prime(u_J^t) \\
\end{array}
\right)  \tag{39} 
\end{eqnarray}

です。以上より、(33)~(39)式を用いると、 \dfrac{\partial E}{\partial \boldsymbol{u} ^ t}の第 j ^ \prime成分 \dfrac{\partial E}{\partial u _ {j ^ \prime} ^ t}は、


\begin{eqnarray}
\dfrac{\partial E}{\partial u _ {j ^ \prime} ^ t} &=& \left(\alpha_{j^\prime} + \displaystyle{\sum_{k=1}^{J}}\left( \displaystyle{\sum_{j=1}^{J}}\dfrac{\partial E}{\partial u_{j}^{t+1}}w_{jk} + \beta_{k} + \gamma_{k} \right) \eta_{k j^\prime} \right) \\
& \times &g_{j^\prime}^{I, t}f^{\prime}(u_{j^\prime}^t) \tag{40}
\end{eqnarray}

となります。

 以上で、LSTMの誤差逆伝播の更新式を導出できました。次の章「RNNとLSTMの更新式の比較」では、ここで導出した(40)式をRNNの更新式((13)式)と比較し、LSTMがRNNよりも勾配消失を起こしにくい理由について説明します。

RNNとLSTMの更新式の比較

 ここでは、これまでの章で導出したRNNの更新式 ((13)式) とLSTMの更新式 ((40)式) を比較します。それぞれの数式は以下の通りです。


\begin{eqnarray}
\dfrac{\partial E}{\partial u_{j^\prime}^t} &=& \left( \displaystyle{\sum_{j=1}^{J}} \dfrac{\partial E}{\partial u_{j}^{t+1}} w_{j j^\prime} + d_{j^\prime}\right) f^\prime(u_{j^\prime}^t) \tag{13} \\
\dfrac{\partial E}{\partial u _ {j ^ \prime} ^ t} &=& \left(\alpha_{j^\prime} + \displaystyle{\sum_{k=1}^{J}}\left( \displaystyle{\sum_{j=1}^{J}}\dfrac{\partial E}{\partial u_{j}^{t+1}}w_{jk} + \beta_{k} + \gamma_{k} \right) \eta_{k j^\prime} \right) \\
& \times &g_{j^\prime}^{I, t}f^{\prime}(u_{j^\prime}^t) \tag{40}
\end{eqnarray}

 RNNでは活性化関数として、双曲線正接関数 {\rm tanh}が頻繁に使われ、その微分は必ず 1以下の値になります。

 また、(13)、(40)式はともに、 \dfrac{\partial E}{\partial u _ {j ^ \prime} ^ {t}}を計算する際、次の時刻の勾配 \dfrac{\partial E}{\partial u _ {j} ^ {t+1}}に活性化関数の微分 f^\prime(u_{j^\prime}^t)を掛けて計算されます。

 したがって、活性化関数に双曲線正接関数を用いた場合、 \dfrac{\partial E}{\partial u _ {j ^ \prime} ^ {t}}の絶対値は \dfrac{\partial E}{\partial u _ {j} ^ {t+1}}の絶対値よりも小さくなる場合があります。この繰り返しにより、時刻が小さくなるにつれて( t=Tから t=1に近づくにつれて)勾配の絶対値が急速に減少し、勾配消失問題が生じます。

 一方で、LSTMの更新式((40)式)でも同様に、 g _ {j ^ \prime} ^ {I, t}f ^ \prime(u _ {j ^ \prime} ^ t)のような 1以下の値を掛けています。しかし、(40)式では(13)式に比べて多くの和の項が含まれています。これにより、LSTMではRNNよりも勾配 \dfrac{\partial E}{\partial u _ {j^ \prime} ^ {t}}の絶対値が減少するスピードが抑えられ、勾配消失を起こしにくくなります。

結論

 RNNでは、出力層から伝播する誤差に、 1以下の値が繰り返し掛け合わされます。このため、誤差が入力側に進むにつれて、誤差の絶対値が急激に小さくなり、勾配消失が発生します。

 LSTMも同様に、誤差に  1 以下の値を掛けますが、更新式に含まれる和の項が多いため、誤差の絶対値が減少しにくくなっています。これにより、LSTMはRNNに比べて勾配消失を抑えることが可能です。

 LSTMはRNNよりも勾配消失を起こしにくいため、より多くの層を重ねることが可能です。その結果、LSTMは系列データの問題に対して、高い精度での予測が可能になりました。

補足:誤差逆伝播の更新式導出に用いた命題の証明

 この節では、RNNとLSTMの更新式導出に用いた以下の「命題:合成関数の微分法」を証明します。

命題:合成関数の微分
  w _ {ij} Eスカラー \boldsymbol{u} ^ t \ (t=2, \dots, T)をベクトルとする。また、 w _ {ij} \to (\boldsymbol{u} ^ 2, \dots, \boldsymbol{u} ^ T) (\boldsymbol{u} ^ 2, \dots, \boldsymbol{u} ^ T) \to Eの二つの写像を考える。このとき、 w _ {ij}による E微分 \dfrac{\partial E}{\partial w _ {ij}}は以下のようになる。

\begin{eqnarray}
\dfrac{\partial E}{\partial w _ {ij}} = \displaystyle{\sum_{t = 2} ^ {T}} \dfrac{\partial E}{\partial \boldsymbol{u}^t} \dfrac{\partial \boldsymbol{u}^t}{\partial w_{ij}}
\end{eqnarray}

(証明)

 \boldsymbol{u} ^ t = (u _ 1 ^ t, \dots, u _ {J_t} ^ t) ^ \topとし、


\begin{eqnarray}
\boldsymbol{U} = \left(
\begin{array}{c}
\boldsymbol{u}^2 \\
\vdots \\
\boldsymbol{u}^T \\
\end{array}
\right)
= \left(
\begin{array}{c}
u_1^2 \\
\vdots \\
u_{j^\prime}^t \\
\vdots \\
u_{J_T}^T \\
\end{array}
\right)
\end{eqnarray}

とします。このとき、 w _ {ij} \to \boldsymbol{U} \to Eです。したがって、多変数関数の合成関数の微分法より


\begin{eqnarray}
\dfrac{\partial E}{\partial w_{ij}} &=& \displaystyle{\sum_{t=2}^{T}} \displaystyle{\sum_{j^\prime=1}^{J_t}} \dfrac{\partial E}{\partial u_{j^\prime}^t} \dfrac{\partial u_{j^\prime}^t}{\partial w_{ij}} \\
&=& \displaystyle{\sum_{t=2}^{T}} \left( \dfrac{\partial E}{\partial u_1^t}, \dots, \dfrac{\partial E}{\partial u_{j^\prime}^t}, \dots, \dfrac{\partial E}{\partial u_{J_t}^t}   \right) 
\left(
\begin{array}{c}
\dfrac{\partial u_1^t}{\partial w_{ij}} \\
\vdots \\
\dfrac{\partial u_{j^\prime}^t}{\partial w_{ij}} \\
\vdots \\
\dfrac{\partial u_{J_t}^t}{\partial w_{ij}} \\
\end{array}
\right) \\
&=& \displaystyle{\sum_{t = 2} ^ {T}} \dfrac{\partial E}{\partial \boldsymbol{u}^t} \dfrac{\partial \boldsymbol{u}^t}{\partial w_{ij}}
\end{eqnarray}

となり、「命題:合成関数の微分法」が成り立ちます。