初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

【PyTorch】テンソルに対する勾配とヤコビ行列

はじめに

 PyTorchのチュートリアルを勉強しています。その中の「0. PyTorch入門 5.自動微分」で分からなかったことがあったので勉強し、まとめてみます。

わからなかったこと

 「0. PyTorch入門 5.自動微分」の最後の方に「補注:テンソルに対する勾配とヤコビ行列」として、出力が多次元テンソルの場合のbackwardの説明が書かれています。しかし、私は最初この説明では何を言っているのか理解することができませんでした。特になぜ下記のコードの出力が
[[4., 2., 2., 2., 2.],
[2., 4., 2., 2., 2.],
[2., 2., 4., 2., 2.],
[2., 2., 2., 4., 2.],
[2., 2., 2., 2., 4.]])
となるのか、理解できませんでした。

inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.ones_like(inp), retain_graph=True)
print("First call\n", inp.grad)

 本記事では上記のコードがどのような計算をしているのか説明しようと思います。

結論

 backwardでgradを計算するとき、内部では下記の計算を行っているようです。

  1. outの一番内側のベクトルのうち、0番目の要素についてヤコビ行列を計算する。
  2. 1で計算されたヤコビ行列に、backwardの引数の一番内側のベクトルを左から掛ける。
  3. 2で計算された結果にinpの一番内側のベクトルの値を代入する。
  4. 1~3をoutの全ての要素に行う。

簡単な例

 下記のコードを例にして説明します。

inp = torch.tensor([1, 0], dtype=torch.float, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.tensor([1, 1]))
print(inp.grad)
# 出力は[4, 2]

outの一番内側のベクトルのうち、0番目の要素についてヤコビ行列を計算する

 inp = [  x_1, x_2 ]とします。このとき、out = (inp+1).pow(2)なので、out = [  (x _ 1+1) ^ 2, (x _ 2+1) ^ 2 ]となります。便宜上、 y _ 1 = (x _ 1+1) ^ 2, y _ 2 = (x _ 2+1) ^ 2とします。以上を用いてヤコビ行列を計算すると、

 \displaystyle
\begin{eqnarray}
J=
\begin{pmatrix}
\dfrac{\partial y_1}{\partial x_1} & \dfrac{\partial y_1}{\partial x_2} \\
\dfrac{\partial y_2}{\partial x_1} & \dfrac{\partial y_2}{\partial x_2} \\
\end{pmatrix}
=
\begin{pmatrix}
2(x_1+1) & 0 \\
0 &  2(x_2+1) \\
\end{pmatrix}
\end{eqnarray}

となります。

1で計算されたヤコビ行列に、backwardの引数の一番内側のベクトルを左から掛ける

 上記で計算されたヤコビ行列にbackwardの引数[1, 1]を左から掛けます。すなわち、

 \displaystyle
\begin{eqnarray}
(1, 1)J=(2(x_1+1), 2(x_2+1))
\end{eqnarray}

となります。

2で計算された結果にinpの一番内側のベクトルの値を代入する

次に、上記で計算された (2(x _ 1+1), 2(x _ 2+1))に、inp = [  x _ 1, x _ 2 ] = [1, 0]を代入します。代入した結果は、[4, 2]です。

1~3をoutの全ての要素に行う

 今回は、outの次元が1だったので3で計算は終わりです。したがって、inp.gradは[4, 2]となります。

最初に提示したコードの場合

 本記事の最初に提示したコード

inp = torch.eye(5, requires_grad=True)
out = (inp+1).pow(2)
out.backward(torch.ones_like(inp), retain_graph=True)
print("First call\n", inp.grad)

の場合、どのような過程を経てinp.gradが計算されるのか見ていきます。

outの一番内側のベクトルのうち、0番目の要素についてヤコビ行列を計算する

 inp = [  x_1, x_2, x_3, x_4, x_5 ]とします。このとき、out = (inp+1).pow(2)なので、out = [  (x _ 1+1) ^ 2, (x _ 2+1) ^ 2, (x _ 3+1) ^ 2, (x _ 4+1) ^ 2, (x _ 5+1) ^ 2 ]となります。便宜上、 y _ 1 = (x _ 1+1) ^ 2, y _ 2 = (x _ 2+1) ^ 2, y _ 3 = (x _ 3+1) ^ 2, y _ 4 = (x _ 4+1) ^ 2,  y _ 5 = (x _ 5+1) ^ 2とします。以上を用いてヤコビ行列を計算すると、

 \displaystyle
\begin{eqnarray}
J=
\begin{pmatrix}
\dfrac{\partial y_1}{\partial x_1} & \dfrac{\partial y_1}{\partial x_2} & \dfrac{\partial y_1}{\partial x_3} & \dfrac{\partial y_1}{\partial x_4} & \dfrac{\partial y_1}{\partial x_5} \\
\dfrac{\partial y_2}{\partial x_1} & \dfrac{\partial y_2}{\partial x_2} & \dfrac{\partial y_2}{\partial x_3} & \dfrac{\partial y_2}{\partial x_4} & \dfrac{\partial y_2}{\partial x_5} \\
\dfrac{\partial y_3}{\partial x_1} & \dfrac{\partial y_3}{\partial x_2} & \dfrac{\partial y_3}{\partial x_3} & \dfrac{\partial y_3}{\partial x_4} & \dfrac{\partial y_3}{\partial x_5} \\
\dfrac{\partial y_4}{\partial x_1} & \dfrac{\partial y_4}{\partial x_2} & \dfrac{\partial y_4}{\partial x_3} & \dfrac{\partial y_4}{\partial x_4} & \dfrac{\partial y_4}{\partial x_5} \\
\dfrac{\partial y_5}{\partial x_1} & \dfrac{\partial y_5}{\partial x_2} & \dfrac{\partial y_5}{\partial x_3} & \dfrac{\partial y_5}{\partial x_4} & \dfrac{\partial y_5}{\partial x_5} \\
\end{pmatrix}
=
\begin{pmatrix}
2(x_1+1) & 0 & 0 & 0 & 0\\
0 &  2(x_2+1) & 0 & 0 & 0\\
0 & 0 & 2(x_3+1) & 0 & 0\\
0 & 0 & 0 & 2(x_4+1) & 0\\
0 & 0 & 0 & 0 & 2(x_5+1)\\
\end{pmatrix}
\end{eqnarray}

となります。

1で計算されたヤコビ行列に、backwardの引数の一番内側のベクトルを左から掛ける

 上記で計算されたヤコビ行列にbackwardの引数[1, 1, 1, 1, 1]を左から掛けます。すなわち、

 \displaystyle
\begin{eqnarray}
(1, 1, 1, 1, 1)J=(2(x_1+1), 2(x_2+1), 2(x_3+1), 2(x_4+1), 2(x_5+1))
\end{eqnarray}

となります。

2で計算された結果にinpの一番内側のベクトルの値を代入する

次に、上記で計算された (2(x _ 1+1), 2(x _ 2+1), 2(x _ 3+1), 2(x _ 4+1), 2(x _ 5+1))に、inp[0] = [  x _ 1, x _ 2, x _ 3 x _ 4, x _ 5 ] = [1, 0, 0, 0, 0]を代入します。代入した結果は、[4, 2, 2, 2, 2]です。

1~3をoutの全ての要素に行う

 上記の計算をinp[1]、inp[2]、inp[3]、inp[4]に対して計算します。その結果が、
[[4., 2., 2., 2., 2.],
[2., 4., 2., 2., 2.],
[2., 2., 4., 2., 2.],
[2., 2., 2., 4., 2.],
[2., 2., 2., 2., 4.]])
です。

まとめ

 本記事では、PyTorchのbackwardを用いてgradを計算する手順について説明しました。計算手順は下記の通りです。

  1. outの一番内側のベクトルのうち、0番目の要素についてヤコビ行列を計算する。
  2. 1で計算されたヤコビ行列に、backwardの引数の一番内側のベクトルを左から掛ける。
  3. 2で計算された結果にinpの一番内側のベクトルの値を代入する。
  4. 1~3をoutの全ての要素に行う。

 ただ、計算手順は理解できたものの、正直、backwardになぜ引数を設定するのかまではわかりませんでした。もし、ご存じの方がいれば教えて頂きたいです。