初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

sklearn.tree.DecisionTreeClassifierのccp_alphaについて

サマリ

  • sklearn.tree.DecisionTreeClassifierは最初に木を成長させてから、枝刈りをするという手順で決定木を作成する。
  • 木を成長させるときは、不純度が小さくなるような頂点を新たに生成することによって行う。
  • 枝刈りは、「最弱リンク枝刈り」という手法を用いて行う。sklearn.tree.DecisionTreeClassifierのccp_alphaはこの手法に関連したパラメータである。
    • ccp_alphaの値が小さい場合、出力される決定木は大きくなり、ccp_alphaの値が大きい場合、出力される決定木は小さくなる。

sklearn.tree.DecisionTreeClassifierの説明

sklearn.tree.DecisionTreeClassifierについて説明されている記事は世の中に多くあるので、ここでは割愛します。公式サイトは下記です。

scikit-learn.org

木を成長させる方法

木を成長させる方法については、下記のQiitaの記事がわかりやすかったです。

qiita.com

ポイントとしては、木を成長させる際に、不純度が小さくなるように頂点を分割することです。Qiitaの記事よりも厳密に数式で表現すると下記のようになります。

  • 変数の設定
    • サンプルに対する添え字:i=1, \cdots, N
    • 特徴量に対する添え字:j=1, \cdots, J
    • 特徴量:\boldsymbol{x}=(x_{ij})
    • 目的変数:\boldsymbol{y}=(y_i)
    • 分割後の頂点の左側の領域:R_L = \left\{x_j \mid x_j \leq s \right\}
    • 分割後の頂点の右側の領域:R_R =  \left\{x_j \mid x_j > s \right\}
    • 分割後の頂点の左側の領域に含まれるサンプルの個数:N_L = \# \left\{\boldsymbol{x} _ i \mid \boldsymbol{x} _ i \in R_L \right\}
    • 分割後の頂点の右側の領域に含まれるサンプルの個数:N_R = \# \left\{\boldsymbol{x} _ i \mid \boldsymbol{x} _ i \in R_R \right\}
    • 分割後の頂点の左側の領域に含まれる観測値の割合:\displaystyle{ \hat{p} _ {Lk} = \frac{1}{N_L} \sum_{\boldsymbol{x} _ i \in R_L}I(y_i=k)}
    • 分割後の頂点の右側の領域に含まれる観測値の割合:\displaystyle{ \hat{p} _ {Rk} = \frac{1}{N_R} \sum_{\boldsymbol{x} _ i \in R_R}I(y_i=k)}

以上のような設定で、下記の式を最小化するように頂点を分割します。

\displaystyle{
{\rm min}_{j, s}  \left\{ N _ L Q _ L(j, s) + N _ R Q _ R(j, s) \right\}
}

上記の式に出てくるQ_L, Q_Rは不純度と呼ばれる式で、誤分類率、ジニ指数、交差エントロピーなどが使われます。それぞれ下記の式です。

DecisionTreeClassifierにおいては、引数criterionによって任意の不純度を設定することができます。

最弱リンク枝刈り

理論

上述した方法で木を成長させたあとに、木の枝刈りを行います。これは、過学習を防ぐために行う作業です。DecisionTreeClassifierにおいては、「最弱リンク枝刈り」という手法を用いて枝刈りを行います。

「最弱リンク枝刈り」とは、下記の式を最大にするように枝を刈る手法です。

\displaystyle{
C_\alpha(T) = N_L Q_L + N_R Q_R +\alpha |T|
}

ただし、|T|は木の終端頂点の個数です。上記の式の\alphaは、DecisionTreeClassifierにおいては、引数ccp_alphaによって設定できます。 \alphaが0に近いと終端頂点の個数が大きい木(深い木)が選択され、大きい値だと終端頂点の個数が小さい木(浅い木)が選ばれます。

\alphaの値を変えて、木の深さを確認

スクリプトは下記の通りです。

スクリプト中の「ccp_alphaと木の深さの関係を図示」を見ると、ccp_alphaが大きくなると浅い木が出力されていることが分かります。