Skip to content

Commit d7bb7fb

Browse files
authored
Update
1 parent edbcc12 commit d7bb7fb

File tree

1 file changed

+10
-1
lines changed
  • MachineLearning/DeepLearning/PyTorch

1 file changed

+10
-1
lines changed
+10-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
## モデルの保存
22
```python
3+
save_path = './dir/file.pth'
34
torch.save(model.state_dict(), save_path) # 学習済みモデルパラメータ, 保存先パス
45
```
56

67
## モデルの読み込み
78
```python
9+
load_path = './dir/file.pth'
810
model = ModelClass(*args, **kwargs)
9-
model.load_state_dict(torch.load(save_path))
11+
model.load_state_dict(torch.load(load_path))
12+
```
13+
14+
### GPU上で保存されたパラメータをGPU上でロードする場合
15+
```python
16+
load_path = './dir/file.pth'
17+
model = ModelClass(*args, **kwargs)
18+
model.load_state_dict(torch.load(load_path, map_location={'cuda:0': 'cpu'}))
1019
```
1120

0 commit comments

Comments
 (0)