初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

畳み込みネットワークの誤差逆伝播

はじめに

 現在、画像認識は自動車の自動運転や医療機関での画像診断など幅広い分野で使われています。また、画像認識に深層学習を用いた結果、精度が飛躍的に向上しました。
 画像認識で深層学習を使用するとき、しばしば畳み込みネットワーク(Convolutional Neural Network:CNN)が使われます。このCNNについて、順伝播を解説する記事は世の中に多くあります。しかし、誤差逆伝播を説明する記事はあまり見かけません。
 そこで本記事では、画像認識の深層学習モデルに多用されるCNNの誤差逆伝播の式について説明します。また、Pythonを用いて、導出した誤差逆伝播の実装を行います。なお、実装した結果はこちらで公開しています。

問題設定

 手書き数字の画像を集めたMNISTデータセットの分類器を図1のような構成で作ります。

図1:ニューラルネットワークの構成

 図1を見るとわかるように、入力画像を16枚の3×3のフィルターに通し(畳み込み層)、双曲線正接関数を通したあと、全結合層に流し込みます。全結合層の出力では、ソフトマックス関数によって出力を0~1の間になるように変換し、最終的に交差エントロピーで誤差を計算します。
 上記で出てきた双曲線正接関数とソフトマックス関数、交差エントロピーについて下記で補足します。

双曲線正接関数

 双曲線正接関数とは、下記で定義される活性化関数です。


\begin{eqnarray}
{\rm tanh}(x) = \dfrac{e^x - e^{-x}}{e^x + e^{-x}}
\end{eqnarray}

 誤差逆伝播の式を導出する際に、双曲線正接関数の微分が出てくるのでここで計算しておくと、


\begin{eqnarray}
\dfrac{d{\rm tanh}(x)}{dx} &=& \dfrac{e^x + e^{-x}}{e^x + e^{-x}} - \dfrac{(e^x-e^{-x})(e^x-e^{-x})}{(e^x+e^{-x})^2} \\
&=& 1- \left(\dfrac{e^x-e^{-x}}{e^x+e^{-x}}\right) \\
&=& 1-({\rm tanh}(x))^2
\end{eqnarray}

となります。

ソフトマックス関数

 ソフトマックス関数とは、下記で定義される活性化関数です。


\begin{eqnarray}
{\rm softmax}(x_k) = \dfrac{e^{x_k}}{\sum_{k^\prime=0}^9 e^{x_{k^\prime}}}
\end{eqnarray}

 今回の場合、MNISTデータセットのラベルが0~9の十値なので、分母の和はk=0から9となっています。
 ソフトマックス関数についても、後で微分した式を用いるので計算しておこうと思います。
 なお、 \dfrac{d {\rm softmax}(x_k)}{d x_i}を計算するのですが、i=kの場合とi \neq kの場合で結果が異なるので、場合分けして説明します。

  • i=kの場合

\begin{eqnarray}
\dfrac{d {\rm softmax}(x_k)}{d x_i} &=& \dfrac{e^{x_k}}{\sum_{k^\prime=0}^9 e^{x_{k^\prime}}} - \dfrac{(e^{x_k})^2}{(\sum_{k^\prime=0}^9 e^{x_{k^\prime}})^2} \\
&=& {\rm softmax}(x_k)(1-{\rm softmax}(x_k))
\end{eqnarray}
  • i \neq kの場合

\begin{eqnarray}
\dfrac{d {\rm softmax}(x_k)}{d x_i} &=& -\dfrac{e^{x_k}e^{x_i}}{(\sum_{k^\prime=0}^9 e^{x_{k^\prime}})^2} \\
&=& -{\rm softmax}(x_k){\rm softmax}(x_i)
\end{eqnarray}

となります。

交差エントロピー

 交差エントロピーとは、下記で定義される誤差関数です。


\begin{eqnarray}
L(\boldsymbol{x}, \boldsymbol{y}) = -\sum_{k=0}^9 y_k {\rm log}x_k
\end{eqnarray}

 ただし、\boldsymbol{x}=(x_0, x_1, \cdots, x_9)は予測値、\boldsymbol{y}=(y_0, y_1, \cdots, y_9)は正解ラベルです。これについても微分した式を計算すると、


\begin{eqnarray}
\dfrac{\partial L}{\partial x_k} = - \dfrac{y_k}{x_k}
\end{eqnarray}

となります。

im2col

 畳み込み演算をする際にim2colと呼ばれる処理を画像に施します。これにより、畳み込み演算を行列積で書くことができ、誤差逆伝播の式がわかりやすくなります。また、畳み込み演算を行列積で書けることにより、高速に計算可能な実装を行うことができます。
 im2colの具体的な方法や実装はim2col徹底理解という記事が分かりやすかったです。ここでは、im2col後の行列をC、フィルターをK としたときに、畳み込み演算はKCと書くことができるとだけ述べておきます。ただし、Kは図2で定義される行列です。

図2:Kの定義

文字の整理

 ここで、この後登場する文字を整理しておきます。

  • 畳み込み演算後の出力をX^{(10)}=KCとする。
  • 中間層の出力をX^{(11)} = {\rm tanh}(X^{(10)})とする。
  • 全結合層の出力をX^{(20)}=WX^{(11)}とする。ただし、Wは全結合層のパラメータである。
  • 出力層の出力をX^{(21)}={\rm softmax}(X^{(20)})とする。

誤差逆伝播の更新式の導出

 誤差逆伝播の更新式に必要な式は、損失関数を各パラメータで微分した下記の2式です。


\begin{eqnarray}
\dfrac{\partial L}{\partial W} \\
\dfrac{\partial L}{\partial K} 
\end{eqnarray}

 まずは簡単な


\begin{eqnarray}
\dfrac{\partial L}{\partial W} \\
\end{eqnarray}

から計算しようと思います。

\dfrac{\partial L}{\partial W}の導出

 行列WX ^ {(11)} X ^ {(20)}の成分をそれぞれw _ {mn} x _ n ^ {(11)} x _ m ^ {(20)}とします。また、X ^ {(20)}の成分数をN _ {(20)}とします。このとき、合成関数の微分の公式から


\begin{eqnarray}
\dfrac{\partial L}{\partial w_{mn}} &=& \sum_{i=1}^{N_{(20)}} \dfrac{\partial x_i^{(20)}}{\partial w_{mn}} \dfrac{\partial L}{\partial x_i^{(20)}} \\
&=& x_n^{(11)}\dfrac{\partial L}{\partial x_m^{(20)}}
\end{eqnarray}

となります。  次に、\dfrac{\partial L}{\partial x_m^{(20)}}を計算していきます。行列X ^ {(21)}の成分をx _ i ^ {(21)}とします。また、X^{(21)}の成分数をN_{(21)}とします。このとき、合成関数の微分の公式から


\begin{eqnarray}
\dfrac{\partial L}{\partial x_m^{(20)}} &=& \sum_{i=1}^{N_{(21)}} \dfrac{\partial x_i^{(21)}}{\partial x_m^{(20)}}\dfrac{\partial L}{\partial x_i^{(21)}} \\
&=& x_m^{(21)}(1-x_m^{(21)}) \left(-\dfrac{y_m}{x_m^{(21)}}\right) - \sum_{i \neq m}x_i^{(21)}x_m^{(21)}\left(-\dfrac{y_i}{x_i^{(21)}}\right) \\
&=& (1-x_m^{(21)})(-y_m)-\sum_{i\neq m}x_m^{(21)}(-y_i) \\
&=& x_m^{(21)}-y_m
\end{eqnarray}

となります。ここでは、上述したソフトマックス関数と交差エントロピー微分、および正解ラベルy_mについて、\displaystyle{\sum_{m=1}^9} y _ m = 1であることを用いました。
 以上より、


\begin{eqnarray}
\dfrac{\partial L}{\partial w_{mn}} = x_n^{(11)}(x_m^{(21)}-y_m)
\end{eqnarray}

が得られます。これを行列表記に変換すると、


\begin{eqnarray}
\dfrac{\partial L}{\partial W} = (X^{(21)}-Y)X^{(11)}
\end{eqnarray}

となります。

 \dfrac{\partial L}{\partial K}の導出

 行列Kの成分をk _ mとします。また、X^{(10)}の成分数をN_{(10)}とします。このとき、合成関数の微分の公式から


\begin{eqnarray}
\dfrac{\partial L}{\partial k_m} &=& \sum_{i=1}^{N_{(10)}} \dfrac{\partial x_i^{(10)}}{\partial k_m} \dfrac{\partial L}{\partial x_i^{(10)}} \\
&=& \sum_{i=1}^{N_{(10)}} c_{mi} \dfrac{\partial L}{\partial x_i^{(10)}}
\end{eqnarray}

となります。次に、\dfrac{\partial L}{\partial x_i^{(10)}}を計算していきます。


\begin{eqnarray}
\dfrac{\partial L}{\partial x_i^{(10)}} &=& \sum_{j=1}^{N_{(11)}} \dfrac{\partial x_j^{(11)}}{\partial x_i^{(10)}} \dfrac{\partial L}{\partial x_j^{(11)}} \\
&=& \left({\rm tanh}(x_i^{(10)})\right)^\prime \dfrac{\partial L}{\partial x_i^{(11)}}
\end{eqnarray}

となります。次に、 \dfrac{\partial L}{\partial x_i^{(11)}} を計算していきます。


\begin{eqnarray}
\dfrac{\partial L}{\partial x_i^{(11)}} &=& \sum_{j=1}^{N_{(20)}} \dfrac{\partial x_j^{(20)}}{\partial x_i^{(11)}} \dfrac{\partial L}{\partial x_j^{(20)}} \\
&=& \sum_{j=1}^{N_{(20)}}w_{ji} \dfrac{\partial L}{\partial x_j^{(20)}} \\
&=& \sum_{j=1}^{N_{(20)}}w_{ji}(x_j^{(21)}-y_j)
\end{eqnarray}

ここでは、 \dfrac{\partial L}{\partial x_j^{(20)}} = x_j^{(21)}-y_jを用いました。
 以上より、


\begin{eqnarray}
\dfrac{\partial L}{\partial k_m} &=& \sum_{i=1}^{N_{(10)}} c_{mi} \dfrac{\partial L}{\partial x_i^{(10)}} \\
&=& \sum_{i=1}^{N_{(10)}} c_{mi} \left( ({\rm tanh}(x_i^{(10)}))^\prime \dfrac{\partial L}{\partial x_i^{(11)}} \right) \\
&=& \sum_{i=1}^{N_{(10)}} c_{mi} \left( ({\rm tanh}(x_i^{(10)}))^\prime \sum_{j=1}^{N_{(20)}} w_{ji} (x_j^{(21)}-y_j)\right) \\
\end{eqnarray}

となります。最後にこれを行列表記すると、


\begin{eqnarray}
\dfrac{\partial L}{\partial K} &=& {\rm tanh}(X^{(10)})^\prime W (X^{(21)}-Y)C^\top
\end{eqnarray}

となります。ただし、 {\rm tanh}(X^{(10)})^\primeは、 {\rm tanh} X^{(10)}の各成分に適用し、微分して得られる行列です。

まとめ

 本記事では、CNNの誤差逆伝播の解説を行いました。
 具体的には、まず問題設定について説明しました。次に、今回のモデルで用いる双曲線正接関数とソフトマックス関数、交差エントロピーの定義を行い、それぞれを微分した式を導出しました。その後、誤差逆伝播による更新式 \dfrac{\partial L}{\partial W}\dfrac{\partial L}{\partial K}の導出を行いました。
 なお、本記事で導出した更新式を用いて、CNNを学習するコードはこちらに配置しました。興味がある方は、ぜひご覧ください。