はじめに
統計学において、分類問題を解く際に目的変数をone hot encodingすることは多々あります。本記事では、PyTorchにおいてカテゴリカル変数をone hot encodingする方法を紹介しようと思います。
フラグが1つの場合
方法1:torch.nn.functional.one_hotを使う
一つ目の方法は、torch.nn.functional.one_hotメソッドを使う方法です。下記のようにして使います。
import torch import torch.nn.functional as F # カテゴリカル変数の定義 categorical = torch.tensor( [0, 3, 9] ) # one hot encodingの実行 one_hot_encoded_1 = F.one_hot(categorical, num_classes=10) # num_classesにクラス数を指定する print(one_hot_encoded_1) # 下記のように表示される # tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
方法2:scatter_を使う
二つ目の方法は、scatter_メソッドを使う方法です。下記のようにして使います。
import torch # カテゴリカル変数を2次元に変更する categorical_unsqueeze = torch.unsqueeze(categorical, dim=1) # one hot encodingの実行 one_hot_encoded_2 = torch.zeros((3, 10)).scatter_(dim=1, index=categorical_unsqueeze, value=1) print(one_hot_encoded_2) # 下記のように表示される # tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
二つ目のscatter_メソッドを使う方法は、少し難しいので一行ごとに解説します。まずは、
categorical_unsqueeze = torch.unsqueeze(categorical, dim=1)
について説明します。
このコードは、tensor([0, 3, 9])の次元を一つ増やして、tensor([[0], [3], [9]])に変換しています。scatter_メソッドは、input.scatter_(dim, index, value)の形式で使います。この際、inputとindexの次元数が同じである必要があります。今回は、inputとして、3×10の二次元のテンソルを指定しているので、indexにあたるcategorical_unsqueezeも二次元に変換する必要があり、unsqueezeしています。
次に、
one_hot_encoded_2 = torch.zeros((3, 10)).scatter_(dim=1, index=categorical_unsqueeze, value=1)
について、説明します。
まずは、torch.zeros((3, 10))を用いて、要素が全て0の3×10の二次元のテンソルを作成します。その後、必要な箇所に1を挿入する作業をscatter_メソッドを用いて行っています。
ここで、scatter_メソッドの説明を簡単にします。
scatter_メソッドは、input.scatter_(dim, index, value)としたときに、下記の作業を行います。ただし、inputは二次元とします。
# dim=0のとき input[index[i][j]][j] = value # dim=1のとき input[i][index[i][j]] = value
今回は、下記のルールにしたがって値を挿入したいのでscatter_(dim=1, index=categorical_unsqueeze, value=1)としています。
input[0][index[0][0]] = value input[1][index[1][0]] = value input[2][index[2][0]] = value
以上のようにして、フラグが1つの場合のone hot encodingを行うことができます。
フラグが2つの場合
下記のような形にone hot encodingをする方法について説明します。
[[0, 1], [3, 7], [4, 9]] を10クラスのone hot enncoding [[1., 1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 0., 1., 0., 0., 0., 0., 1.]] に変換する
実は、この変換をする際にone_hotメソッドを使うことはできません。one_hotメソッドを使うと下記のように変換されてしまいます。
categorical_2 = torch.tensor(
[[0, 1],
[3, 7],
[4, 9]]
)
print(F.one_hot(categorical_2, num_classes=10))
# 下記のように表示される
#tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
# [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]],
# [[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]])
これだと意図した結果と少し異なります。これを解決するにはscatter_メソッドを使う必要があります。今回の場合、下記のルールしたがって値を挿入します。
input[0][index[0][0]] = 1 input[0][index[0][1]] = 1 input[1][index[1][0]] = 1 input[1][index[1][1]] = 1 input[2][index[2][0]] = 1 input[2][index[2][1]] = 1
したがって、scatter_(dim=1, index=categorical_2, value=1)とします。コード全体は下記のようになります。
print(torch.zeros((3, 10)).scatter_(dim=1, index=categorical_2, value=1))
以上より、フラグが2つの場合のone hot encodingができました。フラグが3つ以上の場合も同様の手順でone hot encodingすることができます。
まとめ
PyTorchにおいて、one hot encodingをする方法を紹介しました。下記の手順にしたがって、one hot encodingを行います。
- フラグが1つの場合:
torch.nn.functional.one_hotメソッドもしくは、scatter_メソッドを用いる - フラグが2つ以上の場合:
scatter_メソッドを用いる