初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

【PyTorch】one hot encodingのやり方

はじめに

 統計学において、分類問題を解く際に目的変数をone hot encodingすることは多々あります。本記事では、PyTorchにおいてカテゴリカル変数をone hot encodingする方法を紹介しようと思います。

フラグが1つの場合

方法1:torch.nn.functional.one_hotを使う

 一つ目の方法は、torch.nn.functional.one_hotメソッドを使う方法です。下記のようにして使います。

import torch
import torch.nn.functional as F

# カテゴリカル変数の定義
categorical = torch.tensor(
    [0, 3, 9]
)

# one hot encodingの実行
one_hot_encoded_1 = F.one_hot(categorical, num_classes=10) # num_classesにクラス数を指定する
print(one_hot_encoded_1)
# 下記のように表示される
# tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])      

方法2:scatter_を使う

 二つ目の方法は、scatter_メソッドを使う方法です。下記のようにして使います。

import torch

# カテゴリカル変数を2次元に変更する
categorical_unsqueeze = torch.unsqueeze(categorical, dim=1)

# one hot encodingの実行
one_hot_encoded_2 = torch.zeros((3, 10)).scatter_(dim=1, index=categorical_unsqueeze, value=1)
print(one_hot_encoded_2)
# 下記のように表示される
# tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])

 二つ目のscatter_メソッドを使う方法は、少し難しいので一行ごとに解説します。まずは、

categorical_unsqueeze = torch.unsqueeze(categorical, dim=1)

について説明します。
 このコードは、tensor([0, 3, 9])の次元を一つ増やして、tensor([[0], [3], [9]])に変換しています。scatter_メソッドは、input.scatter_(dim, index, value)の形式で使います。この際、inputindexの次元数が同じである必要があります。今回は、inputとして、3×10の二次元のテンソルを指定しているので、indexにあたるcategorical_unsqueezeも二次元に変換する必要があり、unsqueezeしています。

 次に、

one_hot_encoded_2 = torch.zeros((3, 10)).scatter_(dim=1, index=categorical_unsqueeze, value=1)

について、説明します。
 まずは、torch.zeros((3, 10))を用いて、要素が全て0の3×10の二次元のテンソルを作成します。その後、必要な箇所に1を挿入する作業をscatter_メソッドを用いて行っています。 ここで、scatter_メソッドの説明を簡単にします。
 scatter_メソッドは、input.scatter_(dim, index, value)としたときに、下記の作業を行います。ただし、inputは二次元とします。

# dim=0のとき
input[index[i][j]][j] = value

# dim=1のとき
input[i][index[i][j]] = value

 今回は、下記のルールにしたがって値を挿入したいのでscatter_(dim=1, index=categorical_unsqueeze, value=1)としています。

input[0][index[0][0]] = value
input[1][index[1][0]] = value
input[2][index[2][0]] = value

 以上のようにして、フラグが1つの場合のone hot encodingを行うことができます。

フラグが2つの場合

 下記のような形にone hot encodingをする方法について説明します。

[[0, 1],
 [3, 7],
 [4, 9]]
を10クラスのone hot enncoding
[[1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
 [0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
 [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]]
に変換する

 実は、この変換をする際にone_hotメソッドを使うことはできません。one_hotメソッドを使うと下記のように変換されてしまいます。

categorical_2 = torch.tensor(
    [[0, 1],
     [3, 7],
     [4, 9]]
)

print(F.one_hot(categorical_2, num_classes=10))

# 下記のように表示される
#tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
# [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
# [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])

 これだと意図した結果と少し異なります。これを解決するにはscatter_メソッドを使う必要があります。今回の場合、下記のルールしたがって値を挿入します。

input[0][index[0][0]] = 1
input[0][index[0][1]] = 1

input[1][index[1][0]] = 1
input[1][index[1][1]] = 1

input[2][index[2][0]] = 1
input[2][index[2][1]] = 1

したがって、scatter_(dim=1, index=categorical_2, value=1)とします。コード全体は下記のようになります。

print(torch.zeros((3, 10)).scatter_(dim=1, index=categorical_2, value=1))

 以上より、フラグが2つの場合のone hot encodingができました。フラグが3つ以上の場合も同様の手順でone hot encodingすることができます。

まとめ

 PyTorchにおいて、one hot encodingをする方法を紹介しました。下記の手順にしたがって、one hot encodingを行います。

  • フラグが1つの場合:torch.nn.functional.one_hotメソッドもしくは、scatter_メソッドを用いる
  • フラグが2つ以上の場合:scatter_メソッドを用いる