初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

ニューラルネットワークのドロップアウト

目次

はじめに

 近年、ニューラルネットワークは分類精度の高さや、様々なタスクへの応用が可能であることから、幅広い領域で使われています。私の場合、セキュリティ領域でニューラルネットワークを用いています。
 ところが、ニューラルネットワークはその複雑な構造から、誤った使い方をされてしまうことがあります。誤った使い方を防ぐためには、その基礎理論を学ぶことが重要だと私は考えています。私の場合、なっとく!ディープラーニングという本を通じて、ニューラルネットワークの理論を学んできました。
 本記事では、本書で学んだドロップアウトという手法について説明します。本記事より詳しい内容を確認したい方は、本書の第8章をご覧ください。

ドロップアウトの概要

 ドロップアウトとは、ニューラルネットワークにおける過学習を防ぐための仕組みです。ニューラルネットワークに含まれる重みをランダムに0にすることによって、より頑健なニューラルネットワークを構築することができます。
 本記事では、まず過学習を起こすニューラルネットワークを構築します。その後、学習の過程にドロップアウトを入れることによって、過学習が抑制されることを示します。なお、本記事を執筆する過程で書いたJuliaのコードはGitHubで公開しています。

問題設定

 MNISTデータセットの分類器を、三層のニューラルネットワークを用いて構築します。画像が28×28=784個の特徴量を持っているので、入力層が784個、隠れ層は300個、出力層は0から9の10個のニューラルネットワークを構築します。

過学習を起こすニューラルネットワーク

 まずは、ドロップアウトを入れずに学習を進めたニューラルネットワークを構築します。このとき、横軸にイテレーション、縦軸に学習データとテストデータに対する最小二乗誤差(Mean Square Error:MSE)をとった図が図1です。

図1:ドロップアウトを入れないニューラルネットワークイテレーションとMSEの関係

 図1を見ると、学習データのMSEは単調減少しているのに対し、イテレーションが15を超えたあたりからテストデータのMSEが増加し、その後減少していないことがわかります。これは、ニューラルネットワークが学習データの特徴を捉えすぎることにより、テストデータの分類がうまくいかなくなっている状態です。
 ニューラルネットワークは、その表現力の高さから、過学習を起こしやすいことが知られています。

ドロップアウトを入れたニューラルネットワーク

 ドロップアウトを入れて学習を進めたニューラルネットワークを構築します。このとき、横軸にイテレーション、縦軸に学習データとテストデータに対するMSEをとった図が図2です。

図2:ドロップアウトを入れたニューラルネットワークイテレーションとMSEの関係

 図2を見るとわかるように、学習データとテストデータのMSEが両方とも単調減少しています。これはドロップアウトを入れたことにより、過学習が抑制されているからです。

まとめ

 ニューラルネットワーク過学習を抑制するドロップアウトについて、簡単に説明しました。また、ドロップアウトを入れずに学習を進めたニューラルネットワークと、ドロップアウトを入れて学習を進めたニューラルネットワークのMSEの変化を比較しました。これにより、過学習を防ぐ手立てとしてドロップアウトが有効な場合があることが示されました。