初心者データサイエンティストの備忘録

調べたことは全部ここに書いて自分の辞書を作る

【PyTorch】モデルの保存と読み込み

はじめに

 PyTorchのチュートリアルを勉強しています。本記事では「0. PyTorch入門」の「 7. モデルの保存・読み込み」を学んだ結果をまとめようと思います。

モデルの保存・読み込み方法

 PyTorchで作成したモデルの保存・読み込み方法を図1にまとめました。

図1:モデルの保存・読み込み方法

 具体的なコードは下記です。

モデルの重みだけを保存する

 下記のようにして、モデルの重みを保存します。なお、ファイルの拡張子は.pthが使われることが多いようです。

import torch
from torchvision import models

# 画像認識モデルVGG16の訓練済みモデルを読み込む(例)
model = models.vgg16(pretrained=True)

# モデルの重みのみを保存する
torch.save(model.state_dict(), 保存先のパス)

 読み込む際は、先にモデルのインスタンスを作成しておき、そこに保存された重みを設定します。

import torch
from torchvision import models

# モデルのインスタンスを作成
model = models.vgg16()

# モデルの重みの読み込み
model_weights = torch.load(モデルのパス)

# モデルの重みをモデルのインスタンスに設定
model.load_state_dict(model_weights)

モデルの重みだけでなく構成も保存するーPyTorchでしか読み込めない方法

 下記のようにして、モデルの重みや構成を保存します。こちらも、ファイルの拡張子は.pthが使われることが多いようです。

import torch

# モデル全体の保存
torch.save(model, 保存先のパス)

 下記のコードで読み込みます。

import torch
# モデル全体を読み込む
model = torch.load(モデルのパス)

モデルの重みだけでなく構成も保存するーPyTorch以外でも読み込める方法

 図1中のNo.1とNo.2の保存方法は、PyTorchでモデルを読み込み、推論することができます。しかし、No.3の保存方法では、PyTorchでモデルを読み込み、推論することができません。
 本節ではNo.3で保存されたモデルの読み込みと推論の仕方について説明します。
 まず、onnx形式で保存されたモデルを読み込み、モデルに問題ないこととモデルの構造を確認します。

import onnx

# onnx形式で保存されたモデルの読み込み
model = onnx.load(モデルのパス)

# モデルに問題ないことの確認
onnx.checker.check_model(model)

# モデルの構造の表示
print(onnx.helper.printable_graph(model.graph))

 モデルの構造は下記のように表示されました。

graph main_graph (
  %input.1[FLOAT, 1x3x224x224]
) initializers (
  %features.0.weight[FLOAT, 64x3x3x3]
  %features.0.bias[FLOAT, 64]
~(略)~

 次に推論を実行します。下記のコードで実行します。

import numpy as np
import onnxruntime

# 実行セクションを作成
sess = onnxruntime.InferenceSession(
    モデルのパス,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# データ作成
rng = np.random.default_rng()
input_image = rng.random((1, 3, 224, 224)).astype(np.float32) # 上記でグラフの構造を表示したときのgraph main_graph内に記載されているサイズを入力データのサイズとする

# 推論
output = sess.run(
    None,
    {"input.1":input_image} # 上記でグラフの構造を表示したときのgraph main_graph内に記載されている変数を入力変数とする
)

print("推論されたカテゴリ:", output[0].argmax(1))

まとめ

 本記事では、PyTorchで作成したモデルの保存方法と読み込み方についてまとめました。下記の3種類の方法を適切に使い分けられるようにしていきたいです。

  1. モデルの重みだけを保存する
  2. モデルの重みだけでなく構成も保存する
    1. PyTorchでしか読み込めない方法
    2. PyTorch以外でも読み込める方法