We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent edbcc12 commit d7bb7fbCopy full SHA for d7bb7fb
MachineLearning/DeepLearning/PyTorch/memo.md
@@ -1,11 +1,20 @@
1
## モデルの保存
2
```python
3
+save_path = './dir/file.pth'
4
torch.save(model.state_dict(), save_path) # 学習済みモデルパラメータ, 保存先パス
5
```
6
7
## モデルの読み込み
8
9
+load_path = './dir/file.pth'
10
model = ModelClass(*args, **kwargs)
-model.load_state_dict(torch.load(save_path))
11
+model.load_state_dict(torch.load(load_path))
12
+```
13
+
14
+### GPU上で保存されたパラメータをGPU上でロードする場合
15
+```python
16
17
+model = ModelClass(*args, **kwargs)
18
+model.load_state_dict(torch.load(load_path, map_location={'cuda:0': 'cpu'}))
19
20
0 commit comments