Skip to content

Commit ce62495

Browse files
authored
Add simple save model test (#1227)
* add save model test
1 parent 1ab34a4 commit ce62495

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

jiant/utils/torch_utils.py

+21
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,24 @@ def get_model_for_saving(model: nn.Module) -> nn.Module:
144144
return model.module
145145
else:
146146
return model
147+
148+
149+
def eq_state_dicts(state_dict_1, state_dict_2) -> bool:
    """Checks if the model weights in state_dict_1 and state_dict_2 are equal.

    Entries are matched by key, not by position. Two state dicts are equal only
    if they contain exactly the same keys and every pair of same-keyed tensors
    satisfies ``torch.equal`` (same size and same elements).

    Args:
        state_dict_1 (dict): state_dict of a PyTorch model
        state_dict_2 (dict): state_dict of a PyTorch model

    Returns:
        bool: True if both state dicts have identical keys and all corresponding
            weights are equal; False otherwise (including when one state dict
            has entries the other lacks).
    """
    # Compare the key sets first: pairing entries with zip() would silently stop
    # at the shorter dict and report a truncated state dict as equal.
    if state_dict_1.keys() != state_dict_2.keys():
        return False
    return all(
        torch.equal(weights, state_dict_2[name]) for name, weights in state_dict_1.items()
    )

tests/proj/simple/test_runscript.py

+82
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
import pytest
3+
import torch
34

45
import jiant.utils.python.io as py_io
56
from jiant.proj.simple import runscript as run
67
import jiant.scripts.download_data.runscript as downloader
8+
import jiant.utils.torch_utils as torch_utils
79

810

911
@pytest.mark.parametrize("task_name", ["copa"])
@@ -29,3 +31,83 @@ def test_simple_runscript(tmpdir, task_name, model_type):
2931

3032
val_metrics = py_io.read_json(os.path.join(exp_dir, "runs", RUN_NAME, "val_metrics.json"))
3133
assert val_metrics["aggregated"] > 0
34+
35+
36+
@pytest.mark.gpu
@pytest.mark.parametrize("task_name", ["copa"])
@pytest.mark.parametrize("model_type", ["roberta-large"])
def test_simple_runscript_save(tmpdir, task_name, model_type):
    """Check which model files the simple runscript writes for each save flag.

    Three runs are launched in the same experiment directory:
      1. ``do_save=True``: best and last model files must both exist, and
         their weights must differ.
      2. ``do_save_best=True``: only the best model files must exist.
      3. ``do_save_last=True``: only the last model files must exist.
    """
    # NOTE(review): run names are built from test_simple_runscript's name rather
    # than this test's own name -- confirm this is intended.
    name_prefix = f"{test_simple_runscript.__name__}_{task_name}_{model_type}"
    run_name = f"{name_prefix}_save"
    data_dir = str(tmpdir.mkdir("data"))
    exp_dir = str(tmpdir.mkdir("exp"))

    best_files = ("best_model.p", "best_model.metadata.json")
    last_files = ("last_model.p", "last_model.metadata.json")

    def output_path(name_of_run, filename):
        # Location of a file written for the given run.
        return os.path.join(exp_dir, "runs", name_of_run, filename)

    # Keyword arguments common to all three run configurations.
    shared_kwargs = dict(
        exp_dir=exp_dir,
        data_dir=data_dir,
        model_type=model_type,
        tasks=task_name,
        max_steps=1,
    )

    downloader.download_data([task_name], data_dir)

    # Run 1: do_save
    args = run.RunConfiguration(
        run_name=run_name,
        **shared_kwargs,
        train_batch_size=32,
        do_save=True,
        eval_every_steps=10,
        learning_rate=0.01,
        num_train_epochs=5,
    )
    run.run_simple(args)

    # check best_model and last_model exist
    for filename in best_files + last_files:
        assert os.path.exists(output_path(run_name, filename))

    # assert best_model not equal to last_model
    cpu = torch.device("cpu")
    best_model_weights = torch.load(output_path(run_name, "best_model.p"), map_location=cpu)
    last_model_weights = torch.load(output_path(run_name, "last_model.p"), map_location=cpu)
    assert not torch_utils.eq_state_dicts(best_model_weights, last_model_weights)

    # Run 2: do_save_best
    run_name = f"{name_prefix}_save_best"
    args = run.RunConfiguration(
        run_name=run_name,
        **shared_kwargs,
        train_batch_size=16,
        do_save_best=True,
    )
    run.run_simple(args)

    # check only best_model saved
    for filename in best_files:
        assert os.path.exists(output_path(run_name, filename))
    for filename in last_files:
        assert not os.path.exists(output_path(run_name, filename))

    # Run 3: do_save_last
    run_name = f"{name_prefix}_save_last"
    args = run.RunConfiguration(
        run_name=run_name,
        **shared_kwargs,
        train_batch_size=16,
        do_save_last=True,
    )
    run.run_simple(args)

    # check only last_model saved
    for filename in best_files:
        assert not os.path.exists(output_path(run_name, filename))
    for filename in last_files:
        assert os.path.exists(output_path(run_name, filename))

0 commit comments

Comments
 (0)