はじめに
PyTorchのチュートリアルを勉強しています。本記事では「0. PyTorch入門」の「 7. モデルの保存・読み込み」を学んだ結果をまとめようと思います。
モデルの保存・読み込み方法
PyTorchで作成したモデルの保存・読み込み方法を図1にまとめました。

具体的なコードは下記です。
モデルの重みだけを保存する
下記のようにして、モデルの重みを保存します。なお、ファイルの拡張子は.pthが使われることが多いようです。
import torch from torchvision import models # 画像認識モデルVGG16の訓練済みモデルを読み込む(例) model = models.vgg16(pretrained=True) # モデルの重みのみを保存する torch.save(model.state_dict(), 保存先のパス)
読み込む際は、先にモデルのインスタンスを作成しておき、そこに保存された重みを設定します。
import torch from torchvision import models # モデルのインスタンスを作成 model = models.vgg16() # モデルの重みの読み込み model_weights = torch.load(モデルのパス) # モデルの重みをモデルのインスタンスに設定 model.load_state_dict(model_weights)
モデルの重みだけでなく構成も保存するーPyTorchでしか読み込めない方法
下記のようにして、モデルの重みや構成を保存します。こちらも、ファイルの拡張子は.pthが使われることが多いようです。
import torch # モデル全体の保存 torch.save(model, 保存先のパス)
下記のコードで読み込みます。
import torch # モデル全体を読み込む model = torch.load(モデルのパス)
モデルの重みだけでなく構成も保存するーPyTorch以外でも読み込める方法
図1中のNo.1とNo.2の保存方法は、PyTorchでモデルを読み込み、推論することができます。しかし、No.3の保存方法では、PyTorchでモデルを読み込み、推論することができません。
本節ではNo.3で保存されたモデルの読み込みと推論の仕方について説明します。
まず、onnx形式で保存されたモデルを読み込み、モデルに問題ないこととモデルの構造を確認します。
import onnx # onnx形式で保存されたモデルの読み込み model = onnx.load(モデルのパス) # モデルに問題ないことの確認 onnx.checker.check_model(model) # モデルの構造の表示 print(onnx.helper.printable_graph(model.graph))
モデルの構造は下記のように表示されました。
graph main_graph ( %input.1[FLOAT, 1x3x224x224] ) initializers ( %features.0.weight[FLOAT, 64x3x3x3] %features.0.bias[FLOAT, 64] ~(略)~
次に推論を実行します。下記のコードで実行します。
import numpy as np import onnxruntime # 実行セクションを作成 sess = onnxruntime.InferenceSession( モデルのパス, providers=["CUDAExecutionProvider", "CPUExecutionProvider"] ) # データ作成 rng = np.random.default_rng() input_image = rng.random((1, 3, 224, 224)).astype(np.float32) # 上記でグラフの構造を表示したときのgraph main_graph内に記載されているサイズを入力データのサイズとする # 推論 output = sess.run( None, {"input.1":input_image} # 上記でグラフの構造を表示したときのgraph main_graph内に記載されている変数を入力変数とする ) print("推論されたカテゴリ:", output[0].argmax(1))
まとめ
本記事では、PyTorchで作成したモデルの保存方法と読み込み方についてまとめました。下記の3種類の方法を適切に使い分けられるようにしていきたいです。
- モデルの重みだけを保存する
- モデルの重みだけでなく構成も保存する
- PyTorchでしか読み込めない方法
- PyTorch以外でも読み込める方法