初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

【PyTorch】viewとreshapeの違い

はじめに

 PyTorchには、テンソルを変形するメソッドとしてtorch.Tensor.viewtorch.Tensor.reshapeが用意されています。本記事では、メソッドviewreshapeの違いについてまとめます。

本記事のサマリ

  • viewは要素が順に並んでいるときしか使えない。reshapeは、要素が順に並んでいないときでも、テンソルを変形できる
  • テンソルの要素がメモリ上で順に並んでいるとは、テンソルの要素が連続したメモリに配置されているということ
  • テンソルを転置すると、テンソルの要素がメモリ上で順に並んだ状態ではなくなり、viewを使えなくなる

viewreshapeの違い

 viewreshapeの違いを図1にまとめました。viewは要素が順に並んでいるときしか使えないです。reshapeは、要素が順に並んでいないときでも、テンソルを変形できます。

図1:viewとreshapeの違い

 なお、要素が順に並んでいないときでも、メソッドcontiguousを使って要素を順に並び替えることができます。その後viewを適用すれば、エラーは出力されません。その場合、下記のようにして使います。

import torch

# テンソルを作成
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 転置(このときに、テンソルの要素が順に並んでいない状態になる)
y = torch.t(x)

# viewを適用するとRuntimeErrorを出力する
y_view = y.view(2, 3)
# -> RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

# contiguousを適用してから、viewを適用するとRuntimeErrorは出力されない
y_view = y.contiguous().view(2, 3)

テンソルの要素がメモリ上の順に並んでいるとは?

 ここまで、「テンソルの要素がメモリ上の順に並んでいる」と書いてきました。これがどういう意味を持つのか、もう少し詳しく説明します。
 PyTorchでは、テンソルが作成されたとき、その要素はメモリ上にraw majorと呼ばれる方式で並びます。例えば、2×3のテンソルxを作成したとき、その要素はメモリ上に図2のように並びます。

図2:テンソルの要素のメモリ上での並び方

 この状態を「テンソルの要素がメモリ上の順に並んでいる」と呼んでいます。

テンソルを呼び出すときの挙動

 テンソルが呼び出されたとき、内部では下記の挙動をしています。

  1. x[0, 0]の値をメモリから取り出す
  2. strideに基づいて、各要素の値をメモリから取り出す

 01において、「x[0, 0]の値をメモリから取り出す」と書きました。これは、テンソルxx[0, 0]のメモリ番号が記録されており、その情報を基に値を取り出せます。なお、x[0, 0]のメモリ番号は次のコードで確認できます。

print(x.data_ptr())

# または
print(x[0, 0].data_ptr())

 次に、02の説明をします。
 x[0, 0]より後の要素を取り出すときは、テンソルに記録されているstrideというものを使います。strideとは、各次元において添え字を1つ増やすときに何個隣のメモリを見れば良いのか、を表す指標です。例えば、xの場合、strideは(3, 1)と記録されています。これは、0次元目の添え字を1つ増やすときは3個隣のメモリを、1次元目の添え字を1つ増やすときは1個隣のメモリを見よという意味です。つまり、x[1, 0]を取り出すには、x[0, 0]の3個隣のメモリを見ることで取り出し、x[0, 1]を取り出すには、x[0, 0]の1個隣のメモリを見ることで取り出します(図3)。

図3:strideの使い方

 なお、テンソルのstrideは下記のコードで確認できます。

print(x.stride())

テンソルを転置すると何が起きるか?

 テンソルを転置すると、何が起きるか見ていきます。
 PyTorchにおいては、テンソルを転置しても新しくメモリ上に要素が配置されるわけではありません。strideの値が変わるだけです。今回の場合、xを転置したyのstrideは、xのstrideをひっくり返した(1, 3)となります。したがって、yについて下記のことがいえます。

  • y[0, 0]のメモリ番号はx[0, 0]と同じ
  • y[i+1, j]を取り出すときには、y[i, j]の1個隣を見る
  • y[i, j+1]を取り出すときには、y[i, j]の3個隣を見る

 具体的には、図4のようになります。

図4:yの要素とメモリの位置関係

 このとき、メモリ上でyの要素は図5のように並び、raw majorになっていません。したがって、「テンソルの要素がメモリ上の順に並んでいない」状態となり、viewを使うことができなくなります。

図5:転置したテンソルの各要素のメモリ上での並び方

まとめ

  • viewは要素が順に並んでいるときしか使えない。reshapeは、要素が順に並んでいないときでも、テンソルを変形できる
  • テンソルの要素がメモリ上で順に並んでいるとは、テンソルの要素が連続したメモリに配置されているということ
  • テンソルを転置すると、テンソルの要素がメモリ上で順に並んだ状態ではなくなり、viewを使えなくなる

参考文献