初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

【PyTorch】Negative Log Likelihood・Kullback Leibler距離・Cross Entropyの関係

はじめに

 PyTorchのチュートリアルの「0.PyTorch入門 6.最適化」に、Negative Log Likelihood(以下、NLL)という損失関数が紹介されていました。私はこの損失関数を知らなかったので、調査してみました。その結果、NLLはKullback Leibler距離(以下、KL距離)、およびCross Entropy(以下、CE)と関連がある損失関数とわかりました。本記事では、NLLとKL距離、CEの関係についてまとめようと思います。

Kullback Leibler距離

定義

 まず、KL距離を定義します。データyがしたがう真の分布の確率関数をP(y)P(y)を推定するためのモデルの確率関数をQ(y \mid \theta)とします。\thetaはモデルに含まれるパラメータです。このとき、KL距離は次の式で定義されます。


D_{\rm KL}(P \mid \mid Q) = \displaystyle{\sum_{y}} P(y){\rm log}\dfrac{P(y)}{Q(y \mid \theta)}

 このKL距離について次の命題が成り立ちます。

 P=Qのとき、D_{\rm KL}(P \mid \mid Q)は最小値 0をとる。

 したがって、KL距離は損失関数として使うことができます。また、KL距離を損失関数としたとき、\thetaを推定する問題は、KL距離を最小化する問題に帰着できます。なお、KL距離は距離と呼ばれつつも正確には距離の定義を満たしません。
 また、データから計算されるKL距離は、 \boldsymbol{y} Jクラスの多項分布にしたがう場合、次の式で定義されます。


\hat{D}_{\rm KL}(\boldsymbol{Y} \mid \mid \hat{\boldsymbol{Y}}) = \dfrac{1}{N}\displaystyle{\sum_{i=1}^N}\displaystyle{\sum_{j=1}^J} y_{ij}{\rm log}\dfrac{y_{ij}}{\hat{y}_{ij}}

 なお、上の式で現れる文字は下記の通りです。

  •  N:サンプルサイズ
  •  \boldsymbol{y} _ i = (y _ {i1}, \cdots, y _ {iJ}):one-hot-encodingされている i番目データの正解ラベル
  •  \hat{\boldsymbol{y}} _ i = (\hat{y} _ {i1}, \cdots, \hat{y} _ {iJ}) \sum_ {j=1} ^ J \hat{y} _ {ij} = 1を満たす i番目のデータの予測値
  • \boldsymbol{Y} = (\boldsymbol{y} _ i) _ {i=1, \cdots, N}:全データの正解ラベル
  • \hat{\boldsymbol{Y}} = (\hat{\boldsymbol{y}} _ i) _ {i=1, \cdots, N}:全データの予測値

計算例

 KL距離の計算例として、 y \in {0, 1}の二項分布を考えてみます。真の分布を P(y) = p ^ y (1-p) ^ {(1-y)}、モデルを Q(y \mid \theta) = \theta ^ y (1-\theta) ^ {(1-y)}とします。このとき、KL距離は次のように計算されます。


\begin{eqnarray}
D_{\rm KL}(P \mid \mid Q) &=& \displaystyle{\sum_{y=0}^1} P(y){\rm log}\dfrac{P(y)}{Q(y \mid \theta)} \\
&=& (1-p) {\rm log}\dfrac{1-p}{1-\theta} + p {\rm log}\dfrac{p}{\theta}
\end{eqnarray}

 また、データから計算されるKL距離は正解ラベルを (y, 1-y)、予測値を (\hat{y}, 1-\hat{y})とすると、 N=1のとき


\hat{D}_{\rm KL}(\boldsymbol{Y} \mid \mid \hat{\boldsymbol{Y}}) = (1-y) {\rm log}\dfrac{1-y}{1-\hat{y}} + y {\rm log}\dfrac{y}{\hat{y}}

となります。

Cross Entropy

定義

 次に、CEを定義します。まず、KL距離を分解すると


D_{\rm KL}(P \mid \mid Q) = \displaystyle{\sum_{y}} P(y){\rm log}P(y)-\displaystyle{\sum_{y}} P(y){\rm log}{Q(y \mid \theta)}

となります。このとき、第2項をCEと定義します。つまり、


{\rm CE}(P \mid \mid Q) = -\displaystyle{\sum_{y}} P(y){\rm log}{Q(y \mid \theta)}

です。このとき、次の命題が成り立ちます。

 D_{\rm KL} \thetaによる最小化  \Leftrightarrow  {\rm CE} \thetaによる最小化

 この命題は、


D_{\rm KL}(P \mid \mid Q) = \displaystyle{\sum_{y}} P(y){\rm log}P(y)+{\rm CE}(P \mid \mid Q)

であることから、自明に成り立ちます。
 以上より、KL距離の最小化問題は、CEの最小化問題に帰着できます。
 また、データから計算されるCEは、 \boldsymbol{y} Jクラスの多項分布したがう場合、次の式で定義されます。


\hat{{\rm CE}}(P \mid \mid Q) = -\dfrac{1}{N}\displaystyle{\sum_{i=1}^N}\displaystyle{\sum_{j=1}^J} y_{ij}{\rm log}{\hat{y}_{ij}}

計算例

 CEの計算はKL距離と似ているので、割愛します。

Negative Log Likelihood

定義

 NLLはデータに対して定義されます。 \boldsymbol{y} Jクラスの多項分布にしたがう場合、次の式で定義されます。


\hat{{\rm NLL}}(\boldsymbol{X}, \boldsymbol{Y}) = -\dfrac{1}{N}\displaystyle{\sum_{i=1}^N}\displaystyle{\sum_{j=1}^J} x_{ij} y_{ij}

 なお、 \boldsymbol{X} = (x _ {ij})です。
 この定義より、次の命題が成り立ちます。

 \boldsymbol{X} = {\rm log}\hat{\boldsymbol{Y}}のとき、 {\rm CE} = {\rm NLL}となる。ただし、 {\rm log}\hat{\boldsymbol{Y}} = ({\rm log}(\hat{y}_{ij}))である。

 これは、CEとNLLの定義から自明です。

Negative Log Likelihood・KL距離・CrossEntropyの関係のまとめ

 本記事では、NLL・KL距離・CEに以下の関係があることを説明しました。

  • KL距離の最小化問題とCEの最小化問題は同値
  •  \boldsymbol{X} = {\rm log}\hat{\boldsymbol{Y}}のとき、 {\rm CE} = {\rm NLL}となる

PyTorchでのNLLの使いどころ

 クラス分類をするモデルを考えます。このとき、クラス分類するモデルのよくある構成を図1に示しました。

図1:クラス分類するモデルのよくある構成
   ネットワーク等の出力zをソフトマックス関数で要素の和が1になるように変換し、変換した値と正解ラベルのCEを求める、というモデルです。このとき、CEはPyTorchで提供されているメソッドを組み合わせ、cross_entropy(softmax(z), y)、もしくはnll_loss(log_softmax(z), y)とすることで求められます。これらの関係は下記のようになっています。

図2:2種類のCrossEntropyの求め方

 cross_entropy(softmax(z), y)の方がメソッド名がcross_entropyである分直感的ですが、計算速度はnll_loss(log_softmax(z), y)の方が早く、数値計算も安定しているようです(参考文献)。

まとめ

 本記事では、以下の内容をまとめました。

  • KL距離の最小化問題とCEの最小化問題は同値
  •  \boldsymbol{X} = {\rm log}\hat{\boldsymbol{Y}}のとき、 {\rm CE} = {\rm NLL}となる
  • PyTorchでは、CEをcross_entropy(softmax(z), y)、もしくはnll_loss(log_softmax(z), y)で計算できる
    • cross_entropy(softmax(z), y)の方が直感的であるが、nll_loss(log_softmax(z), y)の方が計算速度が速く、数値計算も安定している