初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

PyTorchのcatとstackの違い

はじめに

 PyTorchにおいてテンソルを結合するメソッドは2種類あります。catstackです。本記事ではこれら2種類のメソッドの違いを説明します。

catとstackの違い

 結論から書くと、catは次元を増やさずにテンソルを結合します。一方で、stackは次元を増やしてテンソルを結合します。これをイメージしたものが、図1です。

図1:catとstackによる結合の違い

 例を挙げてcatstackの違いを説明します。

例1:1次元のテンソル

# import
import torch

# テンソルの定義
t1 = torch.tensor([1])
t2 = torch.tensor([2])

cat_t1_t2 = torch.cat([t1, t2], dim=0)
# -> tensor([1, 2])となる

stack_t1_t2 = torch.stack([t1, t2], dim=0)
# -> tensor([[1], [2]])となる

 なぜこのようになるのか、図2で解説します。

図2:1次元のテンソルのcatとstackによる結合

例2:2次元のテンソル

# import
import torch

# テンソルの定義
t3 = torch.tensor(
    [[1, 2],
     [3, 4]]
)

t4 = torch.tensor(
    [[5, 6],
     [7, 8]]
)

cat_t1_t2_dim0 = torch.cat([t1, t2], dim=0)
# ->  tensor([[1, 2],[3, 4],[5, 6],[7, 8]])となる

cat_t1_t2_dim1 = torch.cat([t1, t2], dim=1)
# -> tensor([[1, 2, 5, 6], [3, 4, 7, 8]])となる

stack_t1_t2_dim0 = torch.stack([t1, t2], dim=0)
# -> tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])となる

stack_t1_t2_dim1 = torch.stack([t1, t2], dim=1)
# -> tensor([[[1, 2], [5, 6]], [[3, 4], [7, 8]]])となる

 なぜこのようになるのか、図3で解説します。

図3:2次元のテンソルのcatとstackによる結合

まとめ

 PyTorchにおけるテンソルを結合するメソッドcatstackの違いを説明しました。