-
Notifications
You must be signed in to change notification settings - Fork 13
/
test.py
37 lines (29 loc) · 1.31 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
import evaluation
import yaml
import torch
def main(opt, current_config):
    """Evaluate a saved checkpoint on the test split.

    The configuration stored inside the checkpoint is reused, except for a
    few path entries that are copied over from ``current_config``. Note that
    this overwrites those entries in ``checkpoint['config']`` in place.

    Args:
        opt: parsed command-line namespace. ``opt.checkpoint`` is the path
            handed to ``torch.load``; ``opt.size`` must be "1k" or "5k".
        current_config: configuration mapping (as loaded from the YAML file)
            that supplies the ``dataset``/``image-model`` path entries.

    Raises:
        ValueError: if ``opt.size`` is neither "1k" nor "5k" (raised after
            the checkpoint has already been loaded).
    """
    checkpoint_path = opt.checkpoint
    checkpoint = torch.load(checkpoint_path)
    print(f'Checkpoint loaded from {checkpoint_path}')
    cfg = checkpoint['config']

    # "1k" -> evalrank(fold5=True), "5k" -> evalrank(fold5=False).
    if opt.size not in ("1k", "5k"):
        raise ValueError('Test split size not recognized!')
    fold5 = opt.size == "1k"

    # Paths are machine-specific, so the ones saved in the checkpoint are
    # replaced with those from the configuration in use now.
    path_entries = (
        ('dataset', 'images-path'),
        ('dataset', 'data'),
        ('image-model', 'pre-extracted-features-root'),
    )
    for section, key in path_entries:
        cfg[section][key] = current_config[section][key]

    evaluation.evalrank(cfg, checkpoint, split="test", fold5=fold5)
if __name__ == '__main__':
    # Command-line entry point: parse arguments, read the YAML config and
    # run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('checkpoint', type=str, help="Checkpoint to load")
    parser.add_argument('--size', type=str, choices=['1k', '5k'], default='1k',
                        help="Test split size")
    # Required: main() reads path entries from this config; when omitted,
    # open(None) below would fail with an unhelpful TypeError.
    parser.add_argument('--config', type=str, required=True,
                        help="Which configuration to use. See into 'config' folder")
    opt = parser.parse_args()

    # yaml.load() without an explicit Loader is rejected by PyYAML >= 6 and
    # can construct arbitrary objects on older versions; safe_load handles
    # standard YAML types (presumably all the config needs - confirm if the
    # config files use custom tags).
    with open(opt.config, 'r') as ymlfile:
        config = yaml.safe_load(ymlfile)
    main(opt, config)